From 54887b1777d4187d51a00c35b81e076578529d12 Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Sat, 6 Feb 2021 22:11:17 +0100 Subject: [PATCH] [Research] KZG polynomial commitment - part 1 FFT (#151) * FFT compiles, now on to debugging ... [skip CI] * Fix FFT and add bench [skip ci] * rename + add KZG resources * rename fft_fr * Implement FFT on elliptic curves =) * FFT G1 bench --- .../arithmetic/finite_fields_inversion.nim | 15 +- research/README.md | 8 + research/kzg_poly_commit/README.md | 12 + research/kzg_poly_commit/fft_fr.nim | 280 ++++++++++++++++ research/kzg_poly_commit/fft_g1.nim | 302 ++++++++++++++++++ research/kzg_poly_commit/fft_lut.nim | 47 +++ research/kzg_poly_commit/strided_views.nim | 247 ++++++++++++++ 7 files changed, 902 insertions(+), 9 deletions(-) create mode 100644 research/README.md create mode 100644 research/kzg_poly_commit/README.md create mode 100644 research/kzg_poly_commit/fft_fr.nim create mode 100644 research/kzg_poly_commit/fft_g1.nim create mode 100644 research/kzg_poly_commit/fft_lut.nim create mode 100644 research/kzg_poly_commit/strided_views.nim diff --git a/constantine/arithmetic/finite_fields_inversion.nim b/constantine/arithmetic/finite_fields_inversion.nim index 6f04bb9..b120c7f 100644 --- a/constantine/arithmetic/finite_fields_inversion.nim +++ b/constantine/arithmetic/finite_fields_inversion.nim @@ -23,13 +23,13 @@ export zoo_inversions {.push raises: [].} {.push inline.} -func inv_euclid*(r: var Fp, a: Fp) = +func inv_euclid*(r: var FF, a: FF) = ## Inversion modulo p via ## Niels Moller constant-time version of ## Stein's GCD derived from extended binary Euclid algorithm - r.mres.steinsGCD(a.mres, Fp.getR2modP(), Fp.C.Mod, Fp.getPrimePlus1div2()) + r.mres.steinsGCD(a.mres, FF.getR2modP(), FF.fieldMod(), FF.getPrimePlus1div2()) -func inv*(r: var Fp, a: Fp) = +func inv*(r: var FF, a: FF) = ## Inversion modulo p ## ## The inverse of 0 is 0. @@ -40,22 +40,19 @@ func inv*(r: var Fp, a: Fp) = # neither for Secp256k1 nor BN curves # Performance is slower than GCD # To be revisited with faster squaring/multiplications - when Fp.C.hasInversionAddchain(): + when FF is Fp and FF.C.hasInversionAddchain(): r.inv_addchain(a) else: r.inv_euclid(a) -func inv*(a: var Fp) = +func inv*(a: var FF) = ## Inversion modulo p ## ## The inverse of 0 is 0. ## Incidentally this avoids extra check ## to convert Jacobian and Projective coordinates ## to affine for elliptic curve - when Fp.C.hasInversionAddchain(): - a.inv_addchain(a) - else: - a.inv_euclid(a) + a.inv(a) {.pop.} # inline {.pop.} # raises no exceptions diff --git a/research/README.md b/research/README.md new file mode 100644 index 0000000..0c0184a --- /dev/null +++ b/research/README.md @@ -0,0 +1,8 @@ +# Research + +This folder stashes experimentations before they are productionized into the library. + +- `kzg`: KZG Polynomial Commitments\ + Constant-Size Commitments to Polynomials and Their Applications\ + Aniket Kate, Gregory M. Zaverucha, Ian Goldberg, 2010\ + https://www.iacr.org/archive/asiacrypt2010/6477178/6477178.pdf diff --git a/research/kzg_poly_commit/README.md b/research/kzg_poly_commit/README.md new file mode 100644 index 0000000..ff4b815 --- /dev/null +++ b/research/kzg_poly_commit/README.md @@ -0,0 +1,12 @@ +# KZG Polynomial Commitment research + +Research for Ethereum 2.0 phase 1 +to implement the Data Availability Sampling protocol + +See + +- https://dankradfeist.de/ethereum/2020/06/16/kate-polynomial-commitments.html +- https://github.com/protolambda/go-kate +- FK20: https://github.com/khovratovich/Kate/blob/master/Kate_amortized.pdf +- https://github.com/ethereum/research/tree/master/polynomial_reconstruction +- https://github.com/ethereum/research/tree/master/kzg_data_availability diff --git a/research/kzg_poly_commit/fft_fr.nim b/research/kzg_poly_commit/fft_fr.nim new file mode 100644 index 0000000..5ef5153 --- /dev/null +++ b/research/kzg_poly_commit/fft_fr.nim @@ -0,0 +1,280 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + ../../constantine/config/curves, + ../../constantine/[arithmetic, primitives], + ../../constantine/io/io_fields, + # Research + ./strided_views, + ./fft_lut + +# See: https://github.com/ethereum/research/blob/master/kzg_data_availability/fft.py +# Quirks of the Python impl: +# - no tests of FFT alone? +# - a lot of "if type(x) == tuple else" +# +# See: https://github.com/protolambda/go-kate/blob/7bb4684/fft_fr.go#L19-L21 +# The go port uses stride+offset to deal with skip iterator. +# +# Other readable FFTs includes: +# - https://github.com/kwantam/fffft +# - https://github.com/ConsenSys/gnark/blob/master/internal/backend/bls381/fft/fft.go +# - https://github.com/poanetwork/threshold_crypto/blob/8820c11/src/poly_vals.rs#L332-L370 +# - https://github.com/zkcrypto/bellman/blob/10c5010/src/domain.rs#L272-L315 +# - Modern Computer Arithmetic, Brent and Zimmermann, p53 algorithm 2.2 +# https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf + +# ############################################################ +# +# Finite-Field Fast Fourier Transform +# +# ############################################################ +# +# This is a research, unoptimized implementation of +# Finite Field Fast Fourier Transform + +# In research phase we tolerate using +# - garbage collected types +# - and exceptions for fast prototyping +# +# In particular, in production all signed integers +# must be verified not to overflow +# and should not throw (or use unsigned) + +# FFT Context +# ---------------------------------------------------------------- + +type + FFTStatus = enum + FFTS_Success + FFTS_TooManyValues = "Input length greater than the field 2-adicity (number of roots of unity)" + FFTS_SizeNotPowerOfTwo = "Input must be of a power of 2 length" + + FFTDescriptor[F] = object + ## Metadata for FFT on field F + maxWidth: int + rootOfUnity: F + ## The root of unity that generates all roots + expandedRootsOfUnity: seq[F] + ## domain, starting and ending with 1 + +func isPowerOf2(n: SomeUnsignedInt): bool = + (n and (n - 1)) == 0 + +func nextPowerOf2(n: uint64): uint64 = + ## Returns x if x is a power of 2 + ## or the next biggest power of 2 + 1'u64 shl (log2(n-1) + 1) + +func expandRootOfUnity[F](rootOfUnity: F): seq[F] = + ## From a generator root of unity + ## expand to width + 1 values. + ## (Last value is 1 for the reverse array) + # For a field of order q, there are gcd(n, q−1) + # nth roots of unity, a.k.a. solutions to xⁿ ≡ 1 (mod q) + # but it's likely too long to compute bigint GCD + # so embrace heap (re-)allocations. + # Figuring out how to do to right size the buffers + # in production will be fun. + result.setLen(2) + result[0].setOne() + result[1] = rootOfUnity + while not result[^1].isOne().bool: + result.setLen(result.len + 1) + result[^1].prod(result[^2], rootOfUnity) + +# FFT Algorithm +# ---------------------------------------------------------------- + +# TODO: research Decimation in Time and Decimation in Frequency +# and FFT butterflies + +func simpleFT[F]( + output: var View[F], + vals: View[F], + rootsOfUnity: View[F] + ) = + # FFT is a recursive algorithm + # This is the base-case using a O(n²) algorithm + + let L = output.len + var last {.noInit.}, v {.noInit.}: F + + for i in 0 ..< L: + last.prod(vals[0], rootsOfUnity[0]) + for j in 1 ..< L: + v.prod(vals[j], rootsOfUnity[(i*j) mod L]) + last += v + output[i] = last + +func fft_internal[F]( + output: var View[F], + vals: View[F], + rootsOfUnity: View[F] + ) = + if output.len <= 4: + simpleFT(output, vals, rootsOfUnity) + return + + # Recursive Divide-and-Conquer + let (evenVals, oddVals) = vals.splitAlternate() + var (outLeft, outRight) = output.splitMiddle() + let halfROI = rootsOfUnity.skipHalf() + + fft_internal(outLeft, evenVals, halfROI) + fft_internal(outRight, oddVals, halfROI) + + let half = outLeft.len + var y_times_root{.noinit.}: F + + for i in 0 ..< half: + # FFT Butterfly + y_times_root .prod(output[i+half], rootsOfUnity[i]) + output[i+half] .diff(output[i], y_times_root) + output[i] += y_times_root + +func fft*[F]( + desc: FFTDescriptor[F], + output: var openarray[F], + vals: openarray[F]): FFT_Status = + if vals.len > desc.maxWidth: + return FFTS_TooManyValues + if not vals.len.uint64.isPowerOf2(): + return FFTS_SizeNotPowerOfTwo + + let rootz = desc.expandedRootsOfUnity + .toView() + .slice(0, desc.maxWidth-1, desc.maxWidth div vals.len) + + var voutput = output.toView() + fft_internal(voutput, vals.toView(), rootz) + return FFTS_Success + +func ifft*[F]( + desc: FFTDescriptor[F], + output: var openarray[F], + vals: openarray[F]): FFT_Status = + ## Inverse FFT + if vals.len > desc.maxWidth: + return FFTS_TooManyValues + if not vals.len.uint64.isPowerOf2(): + return FFTS_SizeNotPowerOfTwo + + let rootz = desc.expandedRootsOfUnity + .toView() + .reversed() + .slice(0, desc.maxWidth-1, desc.maxWidth div vals.len) + + var voutput = output.toView() + fft_internal(voutput, vals.toView(), rootz) + + var invLen {.noInit.}: F + invLen.fromUint(vals.len.uint64) + invLen.inv() + + for i in 0..< output.len: + output[i] *= invLen + + return FFTS_Success + +# FFT Descriptor +# ---------------------------------------------------------------- + +proc init*(T: type FFTDescriptor, maxScale: uint8): T = + result.maxWidth = 1 shl maxScale + result.rootOfUnity = scaleToRootOfUnity(T.F.C)[maxScale] + result.expandedRootsOfUnity = + result.rootOfUnity.expandRootOfUnity() + # Aren't you tired of reading about unity? + +# ############################################################ +# +# Sanity checks +# +# ############################################################ + +{.experimental: "views".} + +when isMainModule: + import + std/[times, monotimes, strformat], + ../../helpers/prng_unsafe + + proc roundtrip() = + let fftDesc = FFTDescriptor[Fr[BLS12_381]].init(maxScale = 4) + var data = newSeq[Fr[BLS12_381]](fftDesc.maxWidth) + for i in 0 ..< fftDesc.maxWidth: + data[i].fromUint i.uint64 + + var coefs = newSeq[Fr[BLS12_381]](data.len) + let fftOk = fft(fftDesc, coefs, data) + doAssert fftOk == FFTS_Success + # display("coefs", 0, coefs) + + var res = newSeq[Fr[BLS12_381]](data.len) + let ifftOk = ifft(fftDesc, res, coefs) + doAssert ifftOk == FFTS_Success + # display("res", 0, coefs) + + for i in 0 ..< res.len: + if bool(res[i] != data[i]): + echo "Error: expected ", data[i].toHex(), " but got ", res[i].toHex() + quit 1 + + echo "FFT round-trip check SUCCESS" + + proc warmup() = + # Warmup - make sure cpu is on max perf + let start = cpuTime() + var foo = 123 + for i in 0 ..< 300_000_000: + foo += i*i mod 456 + foo = foo mod 789 + + # Compiler shouldn't optimize away the results as cpuTime rely on sideeffects + let stop = cpuTime() + echo &"Warmup: {stop - start:>4.4f} s, result {foo} (displayed to avoid compiler optimizing warmup away)\n" + + + proc bench() = + echo "Starting benchmark ..." + const NumIters = 100 + + var rng: RngState + rng.seed 0x1234 + # TODO: view types complain about mutable borrow + # in `random_unsafe` due to pseudo view type LimbsViewMut + # (which was views before Nim properly supported them) + + warmup() + + for scale in 4 ..< 16: + # Setup + + let desc = FFTDescriptor[Fr[BLS12_381]].init(uint8 scale) + var data = newSeq[Fr[BLS12_381]](desc.maxWidth) + for i in 0 ..< desc.maxWidth: + # data[i] = rng.random_unsafe(data[i].typeof()) + data[i].fromUint i.uint64 + + var coefsOut = newSeq[Fr[BLS12_381]](data.len) + + # Bench + let start = getMonotime() + for i in 0 ..< NumIters: + let status = desc.fft(coefsOut, data) + doAssert status == FFTS_Success + let stop = getMonotime() + + let ns = inNanoseconds((stop-start) div NumIters) + echo &"FFT scale {scale:>2} {ns:>8} ns/op" + + roundtrip() + warmup() + bench() diff --git a/research/kzg_poly_commit/fft_g1.nim b/research/kzg_poly_commit/fft_g1.nim new file mode 100644 index 0000000..9f7c499 --- /dev/null +++ b/research/kzg_poly_commit/fft_g1.nim @@ -0,0 +1,302 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + ../../constantine/config/curves, + ../../constantine/[arithmetic, primitives], + ../../constantine/elliptic/[ + ec_endomorphism_accel, + ec_shortweierstrass_affine, + ec_shortweierstrass_projective, + ec_shortweierstrass_jacobian, + ], + ../../constantine/io/[io_fields, io_ec], + # Research + ./strided_views, + ./fft_lut + +# See: https://github.com/ethereum/research/blob/master/kzg_data_availability/fft.py +# Quirks of the Python impl: +# - no tests of FFT alone? +# - a lot of "if type(x) == tuple else" +# +# See: https://github.com/protolambda/go-kate/blob/7bb4684/fft_fr.go#L19-L21 +# The go port uses stride+offset to deal with skip iterator. +# +# Other readable FFTs includes: +# - https://github.com/kwantam/fffft +# - https://github.com/ConsenSys/gnark/blob/master/internal/backend/bls381/fft/fft.go +# - https://github.com/poanetwork/threshold_crypto/blob/8820c11/src/poly_vals.rs#L332-L370 +# - https://github.com/zkcrypto/bellman/blob/10c5010/src/domain.rs#L272-L315 +# - Modern Computer Arithmetic, Brent and Zimmermann, p53 algorithm 2.2 +# https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf + +# ############################################################ +# +# Finite-Field Fast Fourier Transform +# +# ############################################################ +# +# This is a research, unoptimized implementation of +# Finite Field Fast Fourier Transform + +# In research phase we tolerate using +# - garbage collected types +# - and exceptions for fast prototyping +# +# In particular, in production all signed integers +# must be verified not to overflow +# and should not throw (or use unsigned) + +# FFT Context +# ---------------------------------------------------------------- + +type + FFTStatus = enum + FFTS_Success + FFTS_TooManyValues = "Input length greater than the field 2-adicity (number of roots of unity)" + FFTS_SizeNotPowerOfTwo = "Input must be of a power of 2 length" + + FFTDescriptor[EC] = object + ## Metadata for FFT on Elliptic Curve + maxWidth: int + rootOfUnity: matchingOrderBigInt(EC.F.C) + ## The root of unity that generates all roots + expandedRootsOfUnity: seq[matchingOrderBigInt(EC.F.C)] + ## domain, starting and ending with 1 + +func isPowerOf2(n: SomeUnsignedInt): bool = + (n and (n - 1)) == 0 + +func nextPowerOf2(n: uint64): uint64 = + ## Returns x if x is a power of 2 + ## or the next biggest power of 2 + 1'u64 shl (log2(n-1) + 1) + +func expandRootOfUnity[F](rootOfUnity: F): auto {.noInit.} = + ## From a generator root of unity + ## expand to width + 1 values. + ## (Last value is 1 for the reverse array) + # For a field of order q, there are gcd(n, q−1) + # nth roots of unity, a.k.a. solutions to xⁿ ≡ 1 (mod q) + # but it's likely too long to compute bigint GCD + # so embrace heap (re-)allocations. + # Figuring out how to do to right size the buffers + # in production will be fun. + var r: seq[matchingOrderBigInt(F.C)] + r.setLen(2) + r[0].setOne() + r[1] = rootOfUnity.toBig() + + var cur = rootOfUnity + while not r[^1].isOne().bool: + cur *= rootOfUnity + r.setLen(r.len + 1) + r[^1] = cur.toBig() + + return r + +# FFT Algorithm +# ---------------------------------------------------------------- + +func simpleFT[EC; bits: static int]( + output: var View[EC], + vals: View[EC], + rootsOfUnity: View[BigInt[bits]] + ) = + # FFT is a recursive algorithm + # This is the base-case using a O(n²) algorithm + + let L = output.len + var last {.noInit.}, v {.noInit.}: EC + + for i in 0 ..< L: + last = vals[0] + last.scalarMulGLV_m2w2(rootsOfUnity[0]) + for j in 1 ..< L: + v = vals[j] + v.scalarMulGLV_m2w2(rootsOfUnity[(i*j) mod L]) + last += v + output[i] = last + +func fft_internal[EC; bits: static int]( + output: var View[EC], + vals: View[EC], + rootsOfUnity: View[BigInt[bits]] + ) = + if output.len <= 4: + simpleFT(output, vals, rootsOfUnity) + return + + # Recursive Divide-and-Conquer + let (evenVals, oddVals) = vals.splitAlternate() + var (outLeft, outRight) = output.splitMiddle() + let halfROI = rootsOfUnity.skipHalf() + + fft_internal(outLeft, evenVals, halfROI) + fft_internal(outRight, oddVals, halfROI) + + let half = outLeft.len + var y_times_root{.noinit.}: EC + + for i in 0 ..< half: + # FFT Butterfly + y_times_root = output[i+half] + y_times_root .scalarMulGLV_m2w2(rootsOfUnity[i]) + output[i+half] .diff(output[i], y_times_root) + output[i] += y_times_root + +func fft*[EC]( + desc: FFTDescriptor[EC], + output: var openarray[EC], + vals: openarray[EC]): FFT_Status = + if vals.len > desc.maxWidth: + return FFTS_TooManyValues + if not vals.len.uint64.isPowerOf2(): + return FFTS_SizeNotPowerOfTwo + + let rootz = desc.expandedRootsOfUnity + .toView() + .slice(0, desc.maxWidth-1, desc.maxWidth div vals.len) + + var voutput = output.toView() + fft_internal(voutput, vals.toView(), rootz) + return FFTS_Success + +func ifft*[EC]( + desc: FFTDescriptor[EC], + output: var openarray[EC], + vals: openarray[EC]): FFT_Status = + ## Inverse FFT + if vals.len > desc.maxWidth: + return FFTS_TooManyValues + if not vals.len.uint64.isPowerOf2(): + return FFTS_SizeNotPowerOfTwo + + let rootz = desc.expandedRootsOfUnity + .toView() + .reversed() + .slice(0, desc.maxWidth-1, desc.maxWidth div vals.len) + + var voutput = output.toView() + fft_internal(voutput, vals.toView(), rootz) + + var invLen {.noInit.}: Fr[EC.F.C] + invLen.fromUint(vals.len.uint64) + invLen.inv() + let inv = invLen.toBig() + + for i in 0..< output.len: + output[i].scalarMulGLV_m2w2(inv) + + return FFTS_Success + +# FFT Descriptor +# ---------------------------------------------------------------- + +proc init*(T: type FFTDescriptor, maxScale: uint8): T = + result.maxWidth = 1 shl maxScale + + let root = scaleToRootOfUnity(T.EC.F.C)[maxScale] + result.rootOfUnity = root.toBig() + result.expandedRootsOfUnity = root.expandRootOfUnity() + # Aren't you tired of reading about unity? + +# ############################################################ +# +# Sanity checks +# +# ############################################################ + +{.experimental: "views".} + +when isMainModule: + import + std/[times, monotimes, strformat], + ../../helpers/prng_unsafe + + type G1 = ECP_ShortW_Prj[Fp[BLS12_381], NotOnTwist] + var Generator1: ECP_ShortW_Aff[Fp[BLS12_381], NotOnTwist] + doAssert Generator1.fromHex( + "0x17f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb", + "0x08b3f481e3aaa0f1a09e30ed741d8ae4fcf5e095d5d00af600db18cb2c04b3edd03cc744a2888ae40caa232946c5e7e1" + ) + + proc roundtrip() = + let fftDesc = FFTDescriptor[G1].init(maxScale = 4) + var data = newSeq[G1](fftDesc.maxWidth) + data[0].projectiveFromAffine(Generator1) + for i in 1 ..< fftDesc.maxWidth: + data[i].madd(data[i-1], Generator1) + + var coefs = newSeq[G1](data.len) + let fftOk = fft(fftDesc, coefs, data) + doAssert fftOk == FFTS_Success + # display("coefs", 0, coefs) + + var res = newSeq[G1](data.len) + let ifftOk = ifft(fftDesc, res, coefs) + doAssert ifftOk == FFTS_Success + # display("res", 0, coefs) + + for i in 0 ..< res.len: + if bool(res[i] != data[i]): + echo "Error: expected ", data[i].toHex(), " but got ", res[i].toHex() + quit 1 + + echo "FFT round-trip check SUCCESS" + + proc warmup() = + # Warmup - make sure cpu is on max perf + let start = cpuTime() + var foo = 123 + for i in 0 ..< 300_000_000: + foo += i*i mod 456 + foo = foo mod 789 + + # Compiler shouldn't optimize away the results as cpuTime rely on sideeffects + let stop = cpuTime() + echo &"Warmup: {stop - start:>4.4f} s, result {foo} (displayed to avoid compiler optimizing warmup away)\n" + + + proc bench() = + echo "Starting benchmark ..." + const NumIters = 3 + + var rng: RngState + rng.seed 0x1234 + # TODO: view types complain about mutable borrow + # in `random_unsafe` due to pseudo view type LimbsViewMut + # (which was views before Nim properly supported them) + + warmup() + + for scale in 4 ..< 16: + # Setup + + let desc = FFTDescriptor[G1].init(uint8 scale) + var data = newSeq[G1](desc.maxWidth) + data[0].projectiveFromAffine(Generator1) + for i in 1 ..< desc.maxWidth: + data[i].madd(data[i-1], Generator1) + + var coefsOut = newSeq[G1](data.len) + + # Bench + let start = getMonotime() + for i in 0 ..< NumIters: + let status = desc.fft(coefsOut, data) + doAssert status == FFTS_Success + let stop = getMonotime() + + let ns = inNanoseconds((stop-start) div NumIters) + echo &"FFT scale {scale:>2} {ns:>8} ns/op" + + roundtrip() + warmup() + bench() diff --git a/research/kzg_poly_commit/fft_lut.nim b/research/kzg_poly_commit/fft_lut.nim new file mode 100644 index 0000000..3d923d5 --- /dev/null +++ b/research/kzg_poly_commit/fft_lut.nim @@ -0,0 +1,47 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + std/macros, + ../../constantine/config/[curves, common], + ../../constantine/[arithmetic, primitives], + ../../constantine/io/io_fields + +# TODO automate this +# we can precompute everything in Sage +# and auto-generate the file. + +const BLS12_381_Fr_primitive_root = 5 + +func buildRootLUT(F: type Fr): array[32, F] = + ## [pow(PRIMITIVE_ROOT, (MODULUS - 1) // (2**i), MODULUS) for i in range(32)] + + var exponent {.noInit.}: BigInt[F.C.getCurveOrderBitwidth()] + exponent = F.C.getCurveOrder() + exponent -= One + + # Start by the end + var i = result.len - 1 + exponent.shiftRight(i) + result[i].fromUint(BLS12_381_Fr_primitive_root) + result[i].powUnsafeExponent(exponent) + + while i > 0: + result[i-1].square(result[i]) + dec i + + # debugEcho "Fr[BLS12_81] - Roots of Unity:" + # for i in 0 ..< result.len: + # debugEcho " ", i, ": ", result[i].toHex() + # debugEcho "Fr[BLS12_81] - Roots of Unity -- FIN\n" + +let BLS12_381_Fr_ScaleToRootOfUnity* = buildRootLUT(Fr[BLS12_381]) + +{.experimental: "dynamicBindSym".} +macro scaleToRootOfUnity*(C: static Curve): untyped = + return bindSym($C & "_Fr_ScaleToRootOfUnity") diff --git a/research/kzg_poly_commit/strided_views.nim b/research/kzg_poly_commit/strided_views.nim new file mode 100644 index 0000000..979cd19 --- /dev/null +++ b/research/kzg_poly_commit/strided_views.nim @@ -0,0 +1,247 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +# Strided View - Monodimensional Tensors +# ---------------------------------------------------------------- +# +# FFT uses recursive divide-and-conquer. +# In code this means need strided views +# to enable different logical views of the same memory buffer. +# Strided views are monodimensional tensors: +# See Arraymancer backend: +# https://github.com/mratsim/Arraymancer/blob/71cf616/src/arraymancer/laser/tensor/datatypes.nim#L28-L32 +# Or the minimal tensor implementation challenge: +# https://github.com/SimonDanisch/julia-challenge/blob/b8ed3b6/nim/nim_sol_mratsim.nim#L4-L26 + +{.experimental: "views".} + +type + View*[T] = object + ## A strided view over an (unowned) data buffer + len*: int + stride: int + offset: int + data: lent UncheckedArray[T] + +func `[]`*[T](v: View[T], idx: int): lent T {.inline.} = + v.data[v.offset + idx*v.stride] + +func `[]`*[T](v: var View[T], idx: int): var T {.inline.} = + # Experimental views indeed ... + cast[ptr UncheckedArray[T]](v.data)[v.offset + idx*v.stride] + +func `[]=`*[T](v: var View[T], idx: int, val: T) {.inline.} = + # Experimental views indeed ... + cast[ptr UncheckedArray[T]](v.data)[v.offset + idx*v.stride] = val + +func toView*[T](oa: openArray[T]): View[T] {.inline.} = + result.len = oa.len + result.stride = 1 + result.offset = 0 + result.data = cast[lent UncheckedArray[T]](oa[0].unsafeAddr) + +iterator items*[T](v: View[T]): lent T = + var cur = v.offset + for _ in 0 ..< v.len: + yield v.data[cur] + cur += v.stride + +func `$`*(v: View): string = + result = "View[" + var first = true + for elem in v: + if not first: + result &= ", " + else: + first = false + result &= $elem + result &= ']' + +func toHex*(v: View): string = + mixin toHex + + result = "View[" + var first = true + for elem in v: + if not first: + result &= ", " + else: + first = false + result &= elem.toHex() + result &= ']' + +# FFT-specific splitting +# ------------------------------------------------------------------------------- + +func splitAlternate*(t: View): tuple[even, odd: View] {.inline.} = + ## Split the tensor into 2 + ## partitioning the input every other index + ## even: indices [0, 2, 4, ...] + ## odd: indices [ 1, 3, 5, ...] + assert (t.len and 1) == 0, "The tensor must contain an even number of elements" + + let half = t.len shr 1 + let skipHalf = t.stride shl 1 + + result.even.len = half + result.even.stride = skipHalf + result.even.offset = t.offset + result.even.data = t.data + + result.odd.len = half + result.odd.stride = skipHalf + result.odd.offset = t.offset + t.stride + result.odd.data = t.data + +func splitMiddle*(t: View): tuple[left, right: View] {.inline.} = + ## Split the tensor into 2 + ## partitioning into left and right halves. + ## left: indices [0, 1, 2, 3] + ## right: indices [4, 5, 6, 7] + assert (t.len and 1) == 0, "The tensor must contain an even number of elements" + + let half = t.len shr 1 + + result.left.len = half + result.left.stride = t.stride + result.left.offset = t.offset + result.left.data = t.data + + result.right.len = half + result.right.stride = t.stride + result.right.offset = t.offset + half + result.right.data = t.data + +func skipHalf*(t: View): View {.inline.} = + ## Pick one every other indices + ## output: [0, 2, 4, ...] + assert (t.len and 1) == 0, "The tensor must contain an even number of elements" + + result.len = t.len shr 1 + result.stride = t.stride shl 1 + result.offset = t.offset + result.data = t.data + +func slice*(v: View, start, stop, step: int): View {.inline.} = + ## Slice a view + ## stop is inclusive + # General tensor slicing algorithm is + # https://github.com/mratsim/Arraymancer/blob/71cf616/src/arraymancer/tensor/private/p_accessors_macros_read.nim#L26-L56 + # + # for i, slice in slices: + # # Check if we start from the end + # let a = if slice.a_from_end: result.shape[i] - slice.a + # else: slice.a + # + # let b = if slice.b_from_end: result.shape[i] - slice.b + # else: slice.b + # + # # Compute offset: + # result.offset += a * result.strides[i] + # # Now change shape and strides + # result.strides[i] *= slice.step + # result.shape[i] = abs((b-a) div slice.step) + 1 + # + # with slices being of size 1, as we have a monodimensional Tensor + # and the slice being a.. 0 + # + # result is preinitialized with a copy of v (shape, stride, offset, data) + result.offset = v.offset + start * v.stride + result.stride = v.stride * step + result.len = abs((stop-start) div step) + 1 + result.data = v.data + +func reversed*(v: View): View {.inline.} = + # Hopefully the compiler optimizes div by -1 + v.slice(v.len-1, 0, -1) + +# ############################################################ +# +# Debugging helpers +# +# ############################################################ +import strformat, strutils + +func display*[F](name: string, indent: int, oa: openArray[F]) = + debugEcho indent(name & ", openarray of " & $F & " of length " & $oa.len, indent) + for i in 0 ..< oa.len: + debugEcho indent(&" {i:>2}: {oa[i].toHex()}", indent) + debugEcho indent(name & " " & $F & " -- FIN\n", indent) + +func display*[F](name: string, indent: int, v: View[F]) = + debugEcho indent(name & ", view of " & $F & " of length " & $v.len, indent) + for i in 0 ..< v.len: + debugEcho indent(&" {i:>2}: {v[i].toHex()}", indent) + debugEcho indent(name & " " & $F & " -- FIN\n", indent) + +# ############################################################ +# +# Sanity checks +# +# ############################################################ + +when isMainModule: + proc main() = + var x = [0, 1, 2, 3, 4, 5, 6, 7] + let v = x.toView() + + echo "view: ", v + echo "reversed: ", v.reversed() + + block: + let (even, odd) = v.splitAlternate() + echo "\nSplit Alternate" + echo "----------------" + echo "even: ", even + echo "odd: ", odd + + block: + let (ee, eo) = even.splitAlternate() + echo "" + echo "even-even: ", ee + echo "even-odd: ", eo + echo "even-even rev: ", ee.reversed() + echo "even-odd rev: ", eo.reversed() + + block: + let (oe, oo) = odd.splitAlternate() + echo "" + echo "odd-even: ", oe + echo "odd-odd: ", oo + echo "odd-even rev: ", oe.reversed() + echo "odd-odd rev: ", oo.reversed() + + echo "\nSkip Half" + echo "----------------" + echo "skipHalf: ", v.skipHalf() + echo "skipQuad: ", v.skipHalf().skipHalf() + echo "skipQuad rev: ", v.skipHalf().skipHalf().reversed() + + echo "\nSplit middle" + echo "----------------" + block: + let (left, right) = v.splitMiddle() + echo "left: ", left + echo "right: ", right + block: + let (ll, lr) = left.splitMiddle() + echo "" + echo "left-left: ", ll + echo "left-right: ", lr + echo "left-left rev: ", ll.reversed() + echo "left-right rev: ", lr.reversed() + + block: + let (rl, rr) = right.splitMiddle() + echo "" + echo "right-left: ", rl + echo "right-right: ", rr + echo "right-left rev: ", rl.reversed() + echo "right-right rev: ", rr.reversed() + + main()