mirror of
https://github.com/logos-storage/nim-groth16.git
synced 2026-01-05 07:03:09 +00:00
polynomials and NTT
This commit is contained in:
parent
ba04191b72
commit
30ebd2793e
353
bn128.nim
353
bn128.nim
@ -1,6 +1,8 @@
|
||||
|
||||
import std/bitops
|
||||
import std/strutils
|
||||
import std/streams
|
||||
import std/random
|
||||
|
||||
import constantine/math/arithmetic
|
||||
import constantine/math/io/io_fields
|
||||
@ -8,27 +10,44 @@ import constantine/math/io/io_bigints
|
||||
import constantine/math/config/curves
|
||||
import constantine/math/config/type_ff as tff
|
||||
|
||||
import constantine/math/extension_fields/towers as ext
|
||||
import constantine/math/elliptic/ec_shortweierstrass_affine as ell
|
||||
import constantine/math/extension_fields/towers as ext
|
||||
import constantine/math/elliptic/ec_shortweierstrass_affine as aff
|
||||
import constantine/math/elliptic/ec_shortweierstrass_projective as prj
|
||||
import constantine/math/pairings/pairings_bn as ate
|
||||
import constantine/math/elliptic/ec_multi_scalar_mul as msm
|
||||
import constantine/math/elliptic/ec_scalar_mul as scl
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type B* = BigInt[256]
|
||||
type Fr* = tff.Fr[BN254Snarks]
|
||||
type Fp* = tff.Fp[BN254Snarks]
|
||||
type Fp2* = ext.QuadraticExt[Fp]
|
||||
type G1* = ell.ECP_ShortW_Aff[Fp , ell.G1]
|
||||
type G2* = ell.ECP_ShortW_Aff[Fp2, ell.G2]
|
||||
type B* = BigInt[256]
|
||||
type Fr* = tff.Fr[BN254Snarks]
|
||||
type Fp* = tff.Fp[BN254Snarks]
|
||||
|
||||
func mkFp2(i: Fp, u: Fp) : Fp2 =
|
||||
type Fp2* = ext.QuadraticExt[Fp]
|
||||
type Fp12* = ext.Fp12[BN254Snarks]
|
||||
|
||||
type G1* = aff.ECP_ShortW_Aff[Fp , aff.G1]
|
||||
type G2* = aff.ECP_ShortW_Aff[Fp2, aff.G2]
|
||||
|
||||
type ProjG1* = prj.ECP_ShortW_Prj[Fp , prj.G1]
|
||||
type ProjG2* = prj.ECP_ShortW_Prj[Fp2, prj.G2]
|
||||
|
||||
func mkFp2* (i: Fp, u: Fp) : Fp2 =
|
||||
let c : array[2, Fp] = [i,u]
|
||||
return ext.QuadraticExt[Fp]( coords: c )
|
||||
|
||||
func unsafeMkG1( X, Y: Fp ) : G1 =
|
||||
return ell.ECP_ShortW_Aff[Fp, ell.G1](x: X, y: Y)
|
||||
func unsafeMkG1* ( X, Y: Fp ) : G1 =
|
||||
return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y)
|
||||
|
||||
func unsafeMkG2( X, Y: Fp2 ) : G2 =
|
||||
return ell.ECP_ShortW_Aff[Fp2, ell.G2](x: X, y: Y)
|
||||
func unsafeMkG2* ( X, Y: Fp2 ) : G2 =
|
||||
return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func pairing* (p: G1, q: G2) : Fp12 =
|
||||
var t : Fp12
|
||||
pairing_bn( t, p, q )
|
||||
return t
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -37,29 +56,70 @@ const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b97
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
const zeroFp* : Fp = fromHex( Fp, "0x00" )
|
||||
const zeroFr* : Fr = fromHex( Fr, "0x00" )
|
||||
const oneFp* : Fp = fromHex( Fp, "0x01" )
|
||||
const oneFr* : Fr = fromHex( Fr, "0x01" )
|
||||
const zeroFp* : Fp = fromHex( Fp, "0x00" )
|
||||
const zeroFr* : Fr = fromHex( Fr, "0x00" )
|
||||
const oneFp* : Fp = fromHex( Fp, "0x01" )
|
||||
const oneFr* : Fr = fromHex( Fr, "0x01" )
|
||||
|
||||
const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp )
|
||||
const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp )
|
||||
|
||||
const infG1* : G1 = unsafeMkG1( zeroFp , zeroFp )
|
||||
const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func intToFp*(a: int) : Fp =
|
||||
func intToFp*(a: int): Fp =
|
||||
var y : Fp
|
||||
y.fromInt(a)
|
||||
return y
|
||||
|
||||
func intToFr*(a: int) : Fr =
|
||||
func intToFr*(a: int): Fr =
|
||||
var y : Fr
|
||||
y.fromInt(a)
|
||||
return y
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func isZeroFp*(x: Fp): bool = bool(isZero(x))
|
||||
func isZeroFr*(x: Fr): bool = bool(isZero(x))
|
||||
|
||||
func isEquaLFp*(x, y: Fp): bool = bool(x == y)
|
||||
func isEquaLFr*(x, y: Fr): bool = bool(x == y)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func `+`*(x, y: Fp): Fp = ( var z : Fp = x ; z += y ; return z )
|
||||
func `+`*(x, y: Fr): Fr = ( var z : Fr = x ; z += y ; return z )
|
||||
|
||||
func `-`*(x, y: Fp): Fp = ( var z : Fp = x ; z -= y ; return z )
|
||||
func `-`*(x, y: Fr): Fr = ( var z : Fr = x ; z -= y ; return z )
|
||||
|
||||
func `*`*(x, y: Fp): Fp = ( var z : Fp = x ; z *= y ; return z )
|
||||
func `*`*(x, y: Fr): Fr = ( var z : Fr = x ; z *= y ; return z )
|
||||
|
||||
func negFp*(y: Fp): Fp = ( var z : Fp = zeroFp ; z -= y ; return z )
|
||||
func negFr*(y: Fr): Fr = ( var z : Fr = zeroFr ; z -= y ; return z )
|
||||
|
||||
func invFp*(y: Fp): Fp = ( var z : Fp = y ; inv(z) ; return z )
|
||||
func invFr*(y: Fr): Fr = ( var z : Fr = y ; inv(z) ; return z )
|
||||
|
||||
# template/generic instantiation of `pow_vartime` from here
|
||||
# /Users/bkomuves/.nimble/pkgs/constantine-0.0.1/constantine/math/arithmetic/finite_fields.nim(389, 7) template/generic instantiation of `fieldMod` from here
|
||||
# /Users/bkomuves/.nimble/pkgs/constantine-0.0.1/constantine/math/config/curves_prop_field_derived.nim(67, 5) Error: undeclared identifier: 'getCurveOrder'
|
||||
# ...
|
||||
func smallPowFr*(base: Fr, expo: uint): Fr =
|
||||
var a : Fr = oneFr
|
||||
var s : Fr = base
|
||||
var e : uint = expo
|
||||
while (e > 0):
|
||||
if bitand(e,1) > 0: a *= s
|
||||
e = (e shr 1)
|
||||
square(s)
|
||||
return a
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func toDecimalBig*[n](a : BigInt[n]): string =
|
||||
var s : string = toDecimal(a)
|
||||
s = s.strip( leading=true, trailing=false, chars={'0'} )
|
||||
@ -80,16 +140,93 @@ func toDecimalFr*(a : Fr): string =
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc debugPrintSeqFr*(msg: string, xs: seq[Fr]) =
|
||||
echo "---------------------"
|
||||
echo msg
|
||||
for x in xs:
|
||||
echo(" " & toDecimalFr(x))
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# random values
|
||||
|
||||
var randomInitialized : bool = false
|
||||
var randomState : Rand = initRand( 12345 )
|
||||
|
||||
proc rndUint64() : uint64 =
|
||||
return randomState.next()
|
||||
|
||||
proc initializeRandomIfNecessary() =
|
||||
if not randomInitialized:
|
||||
randomState = initRand()
|
||||
randomInitialized = true
|
||||
|
||||
#----------------------------| 01234567890abcdf01234567890abcdf01234567890abcdf01234567890abcdf
|
||||
const m64 : B = fromHex( B, "0x0000000000000000000000000000000000000000000000010000000000000000", bigEndian )
|
||||
const m128 : B = fromHex( B, "0x0000000000000000000000000000000100000000000000000000000000000000", bigEndian )
|
||||
const m192 : B = fromHex( B, "0x0000000000000001000000000000000000000000000000000000000000000000", bigEndian )
|
||||
#----------------------------| 01234567890abcdf01234567890abcdf01234567890abcdf01234567890abcdf
|
||||
|
||||
proc randBig*[bits: static int](): BigInt[bits] =
|
||||
|
||||
initializeRandomIfNecessary()
|
||||
|
||||
let a0 : uint64 = rndUint64()
|
||||
let a1 : uint64 = rndUint64()
|
||||
let a2 : uint64 = rndUint64()
|
||||
let a3 : uint64 = rndUint64()
|
||||
|
||||
# echo((a0,a1,a2,a3))
|
||||
|
||||
var b0 : BigInt[bits] ; b0.fromUint(a0)
|
||||
var b1 : BigInt[bits] ; b1.fromUint(a1)
|
||||
var b2 : BigInt[bits] ; b2.fromUint(a2)
|
||||
var b3 : BigInt[bits] ; b3.fromUint(a3)
|
||||
|
||||
# constantine doesn't appear to have left shift....
|
||||
var c1,c2,c3 : BigInt[bits]
|
||||
prod( c1 , b1 , m64 )
|
||||
prod( c2 , b2 , m128 )
|
||||
prod( c3 , b3 , m192 )
|
||||
|
||||
var d : BigInt[bits]
|
||||
d = b0
|
||||
d += c1
|
||||
d += c2
|
||||
d += c3
|
||||
|
||||
return d
|
||||
|
||||
proc randFr*(): Fr =
|
||||
let b : BigInt[254] = randBig[254]()
|
||||
var y : Fr
|
||||
y.fromBig( b )
|
||||
return y
|
||||
|
||||
proc testRandom*() =
|
||||
for i in 1..20:
|
||||
let x = randFr()
|
||||
echo(x.toHex())
|
||||
echo("-------------------")
|
||||
echo(primeR.toHex())
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func checkCurveEqG1*( x, y: Fp ) : bool =
|
||||
var x2 : Fp = x ; square(x2);
|
||||
var y2 : Fp = y ; square(y2);
|
||||
var x3 : Fp = x2 ; x3 *= x;
|
||||
var eq : Fp
|
||||
eq = x3
|
||||
eq += intToFp(3)
|
||||
eq -= y2
|
||||
# echo("eq = ",toDecimalFp(eq))
|
||||
return (bool(isZero(eq)))
|
||||
if bool(isZero(x)) and bool(isZero(y)):
|
||||
# the point at infinity is on the curve by definition
|
||||
return true
|
||||
else:
|
||||
var x2 : Fp = x ; square(x2);
|
||||
var y2 : Fp = y ; square(y2);
|
||||
var x3 : Fp = x2 ; x3 *= x;
|
||||
var eq : Fp
|
||||
eq = x3
|
||||
eq += intToFp(3)
|
||||
eq -= y2
|
||||
# echo("eq = ",toDecimalFp(eq))
|
||||
return (bool(isZero(eq)))
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
# y^2 = x^3 + B
|
||||
# B = b1 + bu*u
|
||||
@ -100,14 +237,18 @@ const twistCoeffB_u : Fp = fromHex(Fp, "0x009713b03af0fed4cd2cafadeed8fdf4a74fa
|
||||
const twistCoeffB : Fp2 = mkFp2( twistCoeffB_1 , twistCoeffB_u )
|
||||
|
||||
func checkCurveEqG2*( x, y: Fp2 ) : bool =
|
||||
var x2 : Fp2 = x ; square(x2);
|
||||
var y2 : Fp2 = y ; square(y2);
|
||||
var x3 : Fp2 = x2 ; x3 *= x;
|
||||
var eq : Fp2
|
||||
eq = x3
|
||||
eq += twistCoeffB
|
||||
eq -= y2
|
||||
return (bool(isZero(eq)))
|
||||
if bool(isZero(x)) and bool(isZero(y)):
|
||||
# the point at infinity is on the curve by definition
|
||||
return true
|
||||
else:
|
||||
var x2 : Fp2 = x ; square(x2);
|
||||
var y2 : Fp2 = y ; square(y2);
|
||||
var x3 : Fp2 = x2 ; x3 *= x;
|
||||
var eq : Fp2
|
||||
eq = x3
|
||||
eq += twistCoeffB
|
||||
eq -= y2
|
||||
return (bool(isZero(eq)))
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -125,6 +266,14 @@ func mkG2( x, y: Fp2 ) : G2 =
|
||||
assert( checkCurveEqG2(x,y) , "mkG2: not a G2 curve point" )
|
||||
return unsafeMkG2(x,y)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func isOnCurveG1* ( p: G1 ) : bool =
|
||||
return checkCurveEqG1( p.x, p.y )
|
||||
|
||||
func isOnCurveG2* ( p: G2 ) : bool =
|
||||
return checkCurveEqG2( p.x, p.y )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Dealing with Montgomery representation
|
||||
#
|
||||
@ -281,4 +430,136 @@ proc loadPointsG2*( len: int, stream: Stream ) : seq[G2] =
|
||||
points.add( loadPointG2(stream) )
|
||||
return points
|
||||
|
||||
#===============================================================================
|
||||
|
||||
func addG1*(p,q: G1): G1 =
|
||||
var r, x, y : ProjG1
|
||||
prj.fromAffine(x, p)
|
||||
prj.fromAffine(y, q)
|
||||
prj.sum(r, x, y)
|
||||
var s : G1
|
||||
prj.affine(s, r)
|
||||
return s
|
||||
|
||||
func addG2*(p,q: G2): G2 =
|
||||
var r, x, y : ProjG2
|
||||
prj.fromAffine(x, p)
|
||||
prj.fromAffine(y, q)
|
||||
prj.sum(r, x, y)
|
||||
var s : G2
|
||||
prj.affine(s, r)
|
||||
return s
|
||||
|
||||
func negG1*(p: G1): G1 =
|
||||
var r : G1 = p
|
||||
neg(r)
|
||||
return r
|
||||
|
||||
func negG2*(p: G2): G2 =
|
||||
var r : G2 = p
|
||||
neg(r)
|
||||
return r
|
||||
|
||||
func `+`*(p,q: G1): G1 = addG1(p,q)
|
||||
func `+`*(p,q: G2): G2 = addG2(p,q)
|
||||
|
||||
func `+=`*(p: var G1, q: G1) = p = addG1(p,q)
|
||||
func `+=`*(p: var G2, q: G2) = p = addG2(p,q)
|
||||
|
||||
func `-=`*(p: var G1, q: G1) = p = addG1(p,negG1(q))
|
||||
func `-=`*(p: var G2, q: G2) = p = addG2(p,negG2(q))
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# func msmG1( coeffs: seq[Fr] , points: seq[G1] ): G1 =
|
||||
#
|
||||
# let N = coeffs.len
|
||||
# assert( N == points.len, "incompatible sequence lengths" )
|
||||
#
|
||||
# # var arr1 = toOpenArray(coeffs, 0, N-1)
|
||||
# # var arr2 = toOpenArray(points, 0, N-1)
|
||||
#
|
||||
# var bigcfs : seq[BigInt[254]]
|
||||
# for x in coeffs:
|
||||
# bigcfs.add( x.toBig() )
|
||||
#
|
||||
# var r : G1
|
||||
#
|
||||
# # [Fp,aff.G1]
|
||||
# msm.multiScalarMul_vartime( r,
|
||||
# toOpenArray(bigcfs, 0, N-1),
|
||||
# toOpenArray(points, 0, N-1) )
|
||||
#
|
||||
# return r
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
#
|
||||
# (affine) scalar multiplication
|
||||
#
|
||||
|
||||
func `**`*( coeff: Fr , point: G1 ) : G1 =
|
||||
var q : ProjG1
|
||||
prj.fromAffine( q , point )
|
||||
scl.scalarMul( q , coeff.toBig() )
|
||||
var r : G1
|
||||
prj.affine( r, q )
|
||||
return r
|
||||
|
||||
func `**`*( coeff: Fr , point: G2 ) : G2 =
|
||||
var q : ProjG2
|
||||
prj.fromAffine( q , point )
|
||||
scl.scalarMul( q , coeff.toBig() )
|
||||
var r : G2
|
||||
prj.affine( r, q )
|
||||
return r
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func msmNaiveG1( coeffs: seq[Fr] , points: seq[G1] ): G1 =
|
||||
let N = coeffs.len
|
||||
assert( N == points.len, "incompatible sequence lengths" )
|
||||
|
||||
var s : ProjG1
|
||||
s.setInf()
|
||||
|
||||
for i in 0..<N:
|
||||
var t : ProjG1
|
||||
prj.fromAffine( t, points[i] )
|
||||
scl.scalarMul( t , coeffs[i].toBig() )
|
||||
s += t
|
||||
|
||||
var r : G1
|
||||
prj.affine( r, s )
|
||||
|
||||
return r
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
func msmNaiveG2( coeffs: seq[Fr] , points: seq[G2] ): G2 =
|
||||
let N = coeffs.len
|
||||
assert( N == points.len, "incompatible sequence lengths" )
|
||||
|
||||
var s : ProjG2
|
||||
s.setInf()
|
||||
|
||||
for i in 0..<N:
|
||||
var t : ProjG2
|
||||
prj.fromAffine( t, points[i] )
|
||||
scl.scalarMul( t , coeffs[i].toBig() )
|
||||
s += t
|
||||
|
||||
var r : G2
|
||||
prj.affine( r, s)
|
||||
|
||||
return r
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# TODO: proper MSM implementation (couldn't make constantine work at first...)
|
||||
func msmG1*( coeffs: seq[Fr] , points: seq[G1] ): G1 =
|
||||
return msmNaiveG1( coeffs, points )
|
||||
|
||||
func msmG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 =
|
||||
return msmNaiveG2( coeffs, points )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
51
domain.nim
Normal file
51
domain.nim
Normal file
@ -0,0 +1,51 @@
|
||||
|
||||
#
|
||||
# power-of-two sized multiplicative FFT domains in the scalar field
|
||||
#
|
||||
|
||||
import constantine/math/io/io_bigints
|
||||
import constantine/math/arithmetic except Fp,Fr
|
||||
import constantine/math/io/io_fields except Fp,Fr
|
||||
|
||||
import ./bn128
|
||||
import ./misc
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type
|
||||
Domain* = object
|
||||
domainSize* : int
|
||||
logDomainSize* : int
|
||||
domainGen* : Fr
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# the generator of the multiplicative subgroup with size `2^28`
|
||||
const gen28 : Fr = fromHex( Fr, "0x2a3c09f0a58a7e8500e0a7eb8ef62abc402d111e41112ed49bd61b6e725b19f0" )
|
||||
|
||||
func createDomain*(size: int): Domain =
|
||||
let log2 = ceilingLog2(size)
|
||||
assert( (1 shl log2) == size , "domain must have a power-of-two size" )
|
||||
|
||||
let expo : uint = 1'u shl (28 - log2)
|
||||
let gen : Fr = smallPowFr(gen28, expo)
|
||||
|
||||
let halfSize = size div 2
|
||||
let a : Fr = smallPowFr(gen, uint(size ))
|
||||
let b : Fr = smallPowFr(gen, uint(halfSize))
|
||||
assert( bool(a == oneFr) , "domain generator sanity check /A" )
|
||||
assert( not bool(b == oneFr) , "domain generator sanity check /B" )
|
||||
|
||||
return Domain( domainSize:size, logDomainSize:log2, domainGen:gen )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func enumerateDomain*(D: Domain): seq[Fr] =
|
||||
var xs : seq[Fr] = newSeq[Fr](D.domainSize)
|
||||
var g : Fr = oneFr
|
||||
for i in 0..<D.domainSize:
|
||||
xs[i] = g
|
||||
g *= D.domainGen
|
||||
return xs
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
133
ntt.nim
Normal file
133
ntt.nim
Normal file
@ -0,0 +1,133 @@
|
||||
|
||||
#
|
||||
# Number-theoretic transform
|
||||
#
|
||||
|
||||
import std/sequtils
|
||||
|
||||
import constantine/math/arithmetic except Fp,Fr
|
||||
import constantine/math/io/io_fields
|
||||
|
||||
import bn128
|
||||
import domain
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func forwardNTT_worker( m: int
|
||||
, srcStride: int
|
||||
, gen: Fr
|
||||
, src: seq[Fr] , srcOfs: int
|
||||
, buf: var seq[Fr] , bufOfs: int
|
||||
, tgt: var seq[Fr] , tgtOfs: int ) =
|
||||
case m
|
||||
|
||||
of 0:
|
||||
tgt[tgtOfs] = src[srcOfs]
|
||||
|
||||
of 1:
|
||||
tgt[tgtOfs ] = src[srcOfs] + src[srcOfs+srcStride]
|
||||
tgt[tgtOfs+1] = src[srcOfs] - src[srcOfs+srcStride]
|
||||
|
||||
else:
|
||||
let N : int = 1 shl m
|
||||
let halfN : int = 1 shl (m-1)
|
||||
var gpow : Fr = gen
|
||||
square(gpow)
|
||||
forwardNTT_worker( m-1
|
||||
, srcStride shl 1
|
||||
, gpow
|
||||
, src , srcOfs
|
||||
, buf , bufOfs + N
|
||||
, buf , bufOfs )
|
||||
forwardNTT_worker( m-1
|
||||
, srcStride shl 1
|
||||
, gpow
|
||||
, src , srcOfs + srcStride
|
||||
, buf , bufOfs + N
|
||||
, buf , bufOfs + halfN )
|
||||
gpow = oneFr
|
||||
for j in 0..<halfN:
|
||||
let y : Fr = gpow * buf[bufOfs+j+halfN]
|
||||
tgt[tgtOfs+j ] = buf[bufOfs+j] + y
|
||||
tgt[tgtOfs+j+halfN] = buf[bufOfs+j] - y
|
||||
gpow *= gen
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
# forward number-theoretical transform (corresponds to polynomial evaluation)
|
||||
func forwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
|
||||
assert( D.domainSize == (1 shl D.logDomainSize) , "domain must have a power-of-two size" )
|
||||
assert( D.domainSize == src.len , "input must have the same size as the domain" )
|
||||
var buf : seq[Fr] = newSeq[Fr]( 2 * D.domainSize )
|
||||
var tgt : seq[Fr] = newSeq[Fr]( D.domainSize )
|
||||
forwardNTT_worker( D.logDomainSize
|
||||
, 1
|
||||
, D.domainGen
|
||||
, src , 0
|
||||
, buf , 0
|
||||
, tgt , 0 )
|
||||
return tgt
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
const oneHalfFr* : Fr = fromHex(Fr, "0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f8000001")
|
||||
|
||||
func inverseNTT_worker( m: int
|
||||
, tgtStride: int
|
||||
, gen: Fr
|
||||
, src: seq[Fr] , srcOfs: int
|
||||
, buf: var seq[Fr] , bufOfs: int
|
||||
, tgt: var seq[Fr] , tgtOfs: int ) =
|
||||
case m
|
||||
|
||||
of 0:
|
||||
tgt[tgtOfs] = src[srcOfs]
|
||||
|
||||
of 1:
|
||||
tgt[tgtOfs ] = ( src[srcOfs] + src[srcOfs+1] ) * oneHalfFr
|
||||
tgt[tgtOfs+tgtStride] = ( src[srcOfs] - src[srcOfs+1] ) * oneHalfFr
|
||||
|
||||
else:
|
||||
let N : int = 1 shl m
|
||||
let halfN : int = 1 shl (m-1)
|
||||
let ginv : Fr = invFr(gen)
|
||||
var gpow : Fr = oneHalfFr
|
||||
|
||||
for j in 0..<halfN:
|
||||
buf[bufOfs+j ] = ( src[srcOfs+j] + src[srcOfs+j+halfN] ) * oneHalfFr
|
||||
buf[bufOfs+j+halfN] = ( src[srcOfs+j] - src[srcOfs+j+halfN] ) * gpow
|
||||
gpow *= ginv
|
||||
|
||||
gpow = gen
|
||||
square(gpow)
|
||||
|
||||
inverseNTT_worker( m-1
|
||||
, tgtStride shl 1
|
||||
, gpow
|
||||
, buf , bufOfs
|
||||
, buf , bufOfs + N
|
||||
, tgt , tgtOfs )
|
||||
inverseNTT_worker( m-1
|
||||
, tgtStride shl 1
|
||||
, gpow
|
||||
, buf , bufOfs + halfN
|
||||
, buf , bufOfs + N
|
||||
, tgt , tgtOfs + tgtStride )
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
# inverse number-theoretical transform (corresponds to polynomial interpolation)
|
||||
func inverseNTT*(src: seq[Fr], D: Domain): seq[Fr] =
|
||||
assert( D.domainSize == (1 shl D.logDomainSize) , "domain must have a power-of-two size" )
|
||||
assert( D.domainSize == src.len , "input must have the same size as the domain" )
|
||||
var buf : seq[Fr] = newSeq[Fr]( 2 * D.domainSize )
|
||||
var tgt : seq[Fr] = newSeq[Fr]( D.domainSize )
|
||||
inverseNTT_worker( D.logDomainSize
|
||||
, 1
|
||||
, D.domainGen
|
||||
, src , 0
|
||||
, buf , 0
|
||||
, tgt , 0 )
|
||||
return tgt
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
185
poly.nim
Normal file
185
poly.nim
Normal file
@ -0,0 +1,185 @@
|
||||
|
||||
#
|
||||
# univariate polynomials over Fr
|
||||
#
|
||||
# constantine's implementation is "somewhat lacking", so we have to
|
||||
# implement these ourselves...
|
||||
#
|
||||
# TODO: more efficient implementations (right now I just want something working)
|
||||
#
|
||||
|
||||
import std/sequtils
|
||||
import std/sugar
|
||||
|
||||
import constantine/math/arithmetic except Fp,Fr
|
||||
import constantine/math/io/io_fields
|
||||
|
||||
import bn128
|
||||
import domain
|
||||
import ntt
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type
|
||||
Poly* = object
|
||||
coeffs* : seq[Fr]
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func polyDegree*(P: Poly) : int =
|
||||
let xs = P.coeffs ; let n = xs.len
|
||||
var d : int = n-1
|
||||
while isZeroFr(xs[d]) and (d >= 0): d -= 1
|
||||
return d
|
||||
|
||||
func polyIsZero*(P: Poly) : bool =
|
||||
let xs = P.coeffs ; let n = xs.len
|
||||
var b = true
|
||||
for i in 0..<n:
|
||||
if not isZeroFr(xs[i]):
|
||||
b = false
|
||||
break
|
||||
return b
|
||||
|
||||
func polyEqual*(P, Q: Poly) : bool =
|
||||
let xs : seq[Fr] = P.coeffs ; let n = xs.len
|
||||
let ys : seq[Fr] = Q.coeffs ; let m = ys.len
|
||||
var b = true
|
||||
if n >= m:
|
||||
for i in 0..<m: ( if not isEqualFr(xs[i], ys[i]): ( b = false ; break ) )
|
||||
for i in m..<n: ( if not isZeroFr( xs[i] ): ( b = false ; break ) )
|
||||
else:
|
||||
for i in 0..<n: ( if not isEqualFr(xs[i], ys[i]): ( b = false ; break ) )
|
||||
for i in n..<m: ( if not isZeroFr( ys[i]): ( b = false ; break ) )
|
||||
return b
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func polyEvalAt*(P: Poly, x0: Fr): Fr =
|
||||
let cs = P.coeffs ; let n = cs.len
|
||||
var y : Fr = zeroFr
|
||||
var r : Fr = oneFr
|
||||
if n > 0: y = cs[0]
|
||||
for i in 1..<n:
|
||||
r *= x0
|
||||
y += cs[i] * r
|
||||
return y
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func polyNeg*(P: Poly) : Poly =
|
||||
let zs : seq[Fr] = map( P.coeffs , negFr )
|
||||
return Poly(coeffs: zs)
|
||||
|
||||
func polyAdd*(P, Q: Poly) : Poly =
|
||||
let xs = P.coeffs ; let n = xs.len
|
||||
let ys = Q.coeffs ; let m = ys.len
|
||||
var zs : seq[Fr]
|
||||
if n >= m:
|
||||
for i in 0..<m: zs.add( xs[i] + ys[i] )
|
||||
for i in m..<n: zs.add( xs[i] )
|
||||
else:
|
||||
for i in 0..<n: zs.add( xs[i] + ys[i] )
|
||||
for i in n..<m: zs.add( ys[i] )
|
||||
return Poly(coeffs: zs)
|
||||
|
||||
func polySub*(P, Q: Poly) : Poly =
|
||||
let xs = P.coeffs ; let n = xs.len
|
||||
let ys = Q.coeffs ; let m = ys.len
|
||||
var zs : seq[Fr]
|
||||
if n >= m:
|
||||
for i in 0..<m: zs.add( xs[i] - ys[i] )
|
||||
for i in m..<n: zs.add( xs[i] )
|
||||
else:
|
||||
for i in 0..<n: zs.add( xs[i] + ys[i] )
|
||||
for i in n..<m: zs.add( zeroFr - ys[i] )
|
||||
return Poly(coeffs: zs)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func polyScale*(s: Fr, P: Poly): Poly =
|
||||
let zs : seq[Fr] = map( P.coeffs , proc (x: Fr): Fr = s*x )
|
||||
return Poly(coeffs: zs)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func polyMulNaive*(P, Q : Poly): Poly =
|
||||
let xs = P.coeffs ; let n1 = xs.len
|
||||
let ys = Q.coeffs ; let n2 = ys.len
|
||||
var zs : seq[Fr] ; let N = n1 + n2 - 1
|
||||
for k in 0..<N:
|
||||
# 0 <= i <= min(k , n1-1)
|
||||
# 0 <= j <= min(k , n2-1)
|
||||
# k = i + j
|
||||
# 0 >= i = k - j >= k - min(k , n2-1)
|
||||
# 0 >= j = k - i >= k - min(k , n1-1)
|
||||
let A : int = max( 0 , k - min(k , n2-1) )
|
||||
let B : int = min( k , n1-1 )
|
||||
zs[k] = zeroFr
|
||||
for i in A..B:
|
||||
let j = k-i
|
||||
zs[k] += xs[i] * ys[j]
|
||||
return Poly(coeffs: zs)
|
||||
|
||||
func polyMul*(P, Q : Poly): Poly =
|
||||
return polyMulNaive(P, Q)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func `==`*(P, Q: Poly): bool = return polyEqual(P, Q)
|
||||
|
||||
func `+`*(P, Q: Poly): Poly = return polyAdd(P, Q)
|
||||
func `-`*(P, Q: Poly): Poly = return polySub(P, Q)
|
||||
func `*`*(P, Q: Poly): Poly = return polyMul(P, Q)
|
||||
|
||||
func `*`*(s: Fr , P: Poly): Poly = return polyScale(s, P)
|
||||
func `*`*(P: Poly, s: Fr ): Poly = return polyScale(s, P)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# evaluates a polynomial on an FFT domain
|
||||
func polyForwardNTT*(P: Poly, D: Domain): seq[Fr] =
|
||||
let n = P.coeffs.len
|
||||
assert( n <= D.domainSize , "the domain must be as least as big as the polynomial" )
|
||||
|
||||
if n == D.domainSize:
|
||||
let src : seq[Fr] = P.coeffs
|
||||
return forwardNTT(src, D)
|
||||
else:
|
||||
var src : seq[Fr] = P.coeffs
|
||||
for i in n..<D.domainSize: src.add( zeroFr )
|
||||
return forwardNTT(src, D)
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
# interpolates a polynomial on an FFT domain
|
||||
func polyInverseNTT*(ys: seq[Fr], D: Domain): Poly =
|
||||
let n = ys.len
|
||||
assert( n == D.domainSize , "the domain must be same size as the input" )
|
||||
let tgt = inverseNTT(ys, D)
|
||||
return Poly(coeffs: tgt)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc sanityCheckOneHalf*() =
|
||||
let two = oneFr + oneFr
|
||||
let invTwo = oneHalfFr
|
||||
echo(toDecimalFr(two))
|
||||
echo(toDecimalFr(invTwo * two))
|
||||
echo(toHex(invTwo))
|
||||
|
||||
proc sanityCheckPolys*() =
|
||||
var js : seq[int] = toSeq(101..108)
|
||||
let cs : seq[Fr] = map( js, intToFr )
|
||||
let P : Poly = Poly( coeffs:cs )
|
||||
let D : Domain = createDomain(8)
|
||||
let xs : seq[Fr] = D.enumerateDomain()
|
||||
let ys : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(P,x)) )
|
||||
let zs : seq[Fr] = polyForwardNTT(P ,D)
|
||||
let Q : Poly = polyInverseNTT(zs,D)
|
||||
debugPrintSeqFr("xs", xs)
|
||||
debugPrintSeqFr("ys", ys)
|
||||
debugPrintSeqFr("zs", zs)
|
||||
debugPrintSeqFr("us", Q.coeffs)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
Loading…
x
Reference in New Issue
Block a user