polynomials and NTT

This commit is contained in:
Balazs Komuves 2023-11-10 10:12:39 +01:00
parent ba04191b72
commit 30ebd2793e
4 changed files with 686 additions and 36 deletions

353
bn128.nim
View File

@ -1,6 +1,8 @@
import std/bitops
import std/strutils
import std/streams
import std/random
import constantine/math/arithmetic
import constantine/math/io/io_fields
@ -8,27 +10,44 @@ import constantine/math/io/io_bigints
import constantine/math/config/curves
import constantine/math/config/type_ff as tff
import constantine/math/extension_fields/towers as ext
import constantine/math/elliptic/ec_shortweierstrass_affine as ell
import constantine/math/extension_fields/towers as ext
import constantine/math/elliptic/ec_shortweierstrass_affine as aff
import constantine/math/elliptic/ec_shortweierstrass_projective as prj
import constantine/math/pairings/pairings_bn as ate
import constantine/math/elliptic/ec_multi_scalar_mul as msm
import constantine/math/elliptic/ec_scalar_mul as scl
#-------------------------------------------------------------------------------
type B* = BigInt[256]
type Fr* = tff.Fr[BN254Snarks]
type Fp* = tff.Fp[BN254Snarks]
type Fp2* = ext.QuadraticExt[Fp]
type G1* = ell.ECP_ShortW_Aff[Fp , ell.G1]
type G2* = ell.ECP_ShortW_Aff[Fp2, ell.G2]
type B* = BigInt[256]
type Fr* = tff.Fr[BN254Snarks]
type Fp* = tff.Fp[BN254Snarks]
func mkFp2(i: Fp, u: Fp) : Fp2 =
type Fp2* = ext.QuadraticExt[Fp]
type Fp12* = ext.Fp12[BN254Snarks]
type G1* = aff.ECP_ShortW_Aff[Fp , aff.G1]
type G2* = aff.ECP_ShortW_Aff[Fp2, aff.G2]
type ProjG1* = prj.ECP_ShortW_Prj[Fp , prj.G1]
type ProjG2* = prj.ECP_ShortW_Prj[Fp2, prj.G2]
func mkFp2* (i: Fp, u: Fp) : Fp2 =
let c : array[2, Fp] = [i,u]
return ext.QuadraticExt[Fp]( coords: c )
func unsafeMkG1( X, Y: Fp ) : G1 =
return ell.ECP_ShortW_Aff[Fp, ell.G1](x: X, y: Y)
func unsafeMkG1* ( X, Y: Fp ) : G1 =
return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y)
func unsafeMkG2( X, Y: Fp2 ) : G2 =
return ell.ECP_ShortW_Aff[Fp2, ell.G2](x: X, y: Y)
func unsafeMkG2* ( X, Y: Fp2 ) : G2 =
return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y)
#-------------------------------------------------------------------------------
func pairing* (p: G1, q: G2) : Fp12 =
var t : Fp12
pairing_bn( t, p, q )
return t
#-------------------------------------------------------------------------------
@ -37,29 +56,70 @@ const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b97
#-------------------------------------------------------------------------------
const zeroFp* : Fp = fromHex( Fp, "0x00" )
const zeroFr* : Fr = fromHex( Fr, "0x00" )
const oneFp* : Fp = fromHex( Fp, "0x01" )
const oneFr* : Fr = fromHex( Fr, "0x01" )
const zeroFp* : Fp = fromHex( Fp, "0x00" )
const zeroFr* : Fr = fromHex( Fr, "0x00" )
const oneFp* : Fp = fromHex( Fp, "0x01" )
const oneFr* : Fr = fromHex( Fr, "0x01" )
const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp )
const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp )
const infG1* : G1 = unsafeMkG1( zeroFp , zeroFp )
const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 )
#-------------------------------------------------------------------------------
func intToFp*(a: int) : Fp =
func intToFp*(a: int): Fp =
var y : Fp
y.fromInt(a)
return y
func intToFr*(a: int) : Fr =
func intToFr*(a: int): Fr =
var y : Fr
y.fromInt(a)
return y
#-------------------------------------------------------------------------------
func isZeroFp*(x: Fp): bool = bool(isZero(x))
func isZeroFr*(x: Fr): bool = bool(isZero(x))
func isEquaLFp*(x, y: Fp): bool = bool(x == y)
func isEquaLFr*(x, y: Fr): bool = bool(x == y)
#-------------------------------------------------------------------------------
func `+`*(x, y: Fp): Fp = ( var z : Fp = x ; z += y ; return z )
func `+`*(x, y: Fr): Fr = ( var z : Fr = x ; z += y ; return z )
func `-`*(x, y: Fp): Fp = ( var z : Fp = x ; z -= y ; return z )
func `-`*(x, y: Fr): Fr = ( var z : Fr = x ; z -= y ; return z )
func `*`*(x, y: Fp): Fp = ( var z : Fp = x ; z *= y ; return z )
func `*`*(x, y: Fr): Fr = ( var z : Fr = x ; z *= y ; return z )
func negFp*(y: Fp): Fp = ( var z : Fp = zeroFp ; z -= y ; return z )
func negFr*(y: Fr): Fr = ( var z : Fr = zeroFr ; z -= y ; return z )
func invFp*(y: Fp): Fp = ( var z : Fp = y ; inv(z) ; return z )
func invFr*(y: Fr): Fr = ( var z : Fr = y ; inv(z) ; return z )
# template/generic instantiation of `pow_vartime` from here
# /Users/bkomuves/.nimble/pkgs/constantine-0.0.1/constantine/math/arithmetic/finite_fields.nim(389, 7) template/generic instantiation of `fieldMod` from here
# /Users/bkomuves/.nimble/pkgs/constantine-0.0.1/constantine/math/config/curves_prop_field_derived.nim(67, 5) Error: undeclared identifier: 'getCurveOrder'
# ...
func smallPowFr*(base: Fr, expo: uint): Fr =
var a : Fr = oneFr
var s : Fr = base
var e : uint = expo
while (e > 0):
if bitand(e,1) > 0: a *= s
e = (e shr 1)
square(s)
return a
#-------------------------------------------------------------------------------
func toDecimalBig*[n](a : BigInt[n]): string =
var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} )
@ -80,16 +140,93 @@ func toDecimalFr*(a : Fr): string =
#-------------------------------------------------------------------------------
proc debugPrintSeqFr*(msg: string, xs: seq[Fr]) =
echo "---------------------"
echo msg
for x in xs:
echo(" " & toDecimalFr(x))
#-------------------------------------------------------------------------------
# random values
var randomInitialized : bool = false
var randomState : Rand = initRand( 12345 )
proc rndUint64() : uint64 =
return randomState.next()
proc initializeRandomIfNecessary() =
if not randomInitialized:
randomState = initRand()
randomInitialized = true
#----------------------------| 01234567890abcdf01234567890abcdf01234567890abcdf01234567890abcdf
const m64 : B = fromHex( B, "0x0000000000000000000000000000000000000000000000010000000000000000", bigEndian )
const m128 : B = fromHex( B, "0x0000000000000000000000000000000100000000000000000000000000000000", bigEndian )
const m192 : B = fromHex( B, "0x0000000000000001000000000000000000000000000000000000000000000000", bigEndian )
#----------------------------| 01234567890abcdf01234567890abcdf01234567890abcdf01234567890abcdf
proc randBig*[bits: static int](): BigInt[bits] =
initializeRandomIfNecessary()
let a0 : uint64 = rndUint64()
let a1 : uint64 = rndUint64()
let a2 : uint64 = rndUint64()
let a3 : uint64 = rndUint64()
# echo((a0,a1,a2,a3))
var b0 : BigInt[bits] ; b0.fromUint(a0)
var b1 : BigInt[bits] ; b1.fromUint(a1)
var b2 : BigInt[bits] ; b2.fromUint(a2)
var b3 : BigInt[bits] ; b3.fromUint(a3)
# constantine doesn't appear to have left shift....
var c1,c2,c3 : BigInt[bits]
prod( c1 , b1 , m64 )
prod( c2 , b2 , m128 )
prod( c3 , b3 , m192 )
var d : BigInt[bits]
d = b0
d += c1
d += c2
d += c3
return d
proc randFr*(): Fr =
let b : BigInt[254] = randBig[254]()
var y : Fr
y.fromBig( b )
return y
proc testRandom*() =
for i in 1..20:
let x = randFr()
echo(x.toHex())
echo("-------------------")
echo(primeR.toHex())
#-------------------------------------------------------------------------------
func checkCurveEqG1*( x, y: Fp ) : bool =
var x2 : Fp = x ; square(x2);
var y2 : Fp = y ; square(y2);
var x3 : Fp = x2 ; x3 *= x;
var eq : Fp
eq = x3
eq += intToFp(3)
eq -= y2
# echo("eq = ",toDecimalFp(eq))
return (bool(isZero(eq)))
if bool(isZero(x)) and bool(isZero(y)):
# the point at infinity is on the curve by definition
return true
else:
var x2 : Fp = x ; square(x2);
var y2 : Fp = y ; square(y2);
var x3 : Fp = x2 ; x3 *= x;
var eq : Fp
eq = x3
eq += intToFp(3)
eq -= y2
# echo("eq = ",toDecimalFp(eq))
return (bool(isZero(eq)))
#---------------------------------------
# y^2 = x^3 + B
# B = b1 + bu*u
@ -100,14 +237,18 @@ const twistCoeffB_u : Fp = fromHex(Fp, "0x009713b03af0fed4cd2cafadeed8fdf4a74fa
const twistCoeffB : Fp2 = mkFp2( twistCoeffB_1 , twistCoeffB_u )
func checkCurveEqG2*( x, y: Fp2 ) : bool =
var x2 : Fp2 = x ; square(x2);
var y2 : Fp2 = y ; square(y2);
var x3 : Fp2 = x2 ; x3 *= x;
var eq : Fp2
eq = x3
eq += twistCoeffB
eq -= y2
return (bool(isZero(eq)))
if bool(isZero(x)) and bool(isZero(y)):
# the point at infinity is on the curve by definition
return true
else:
var x2 : Fp2 = x ; square(x2);
var y2 : Fp2 = y ; square(y2);
var x3 : Fp2 = x2 ; x3 *= x;
var eq : Fp2
eq = x3
eq += twistCoeffB
eq -= y2
return (bool(isZero(eq)))
#-------------------------------------------------------------------------------
@ -125,6 +266,14 @@ func mkG2( x, y: Fp2 ) : G2 =
assert( checkCurveEqG2(x,y) , "mkG2: not a G2 curve point" )
return unsafeMkG2(x,y)
#-------------------------------------------------------------------------------
func isOnCurveG1* ( p: G1 ) : bool =
return checkCurveEqG1( p.x, p.y )
func isOnCurveG2* ( p: G2 ) : bool =
return checkCurveEqG2( p.x, p.y )
#-------------------------------------------------------------------------------
# Dealing with Montgomery representation
#
@ -281,4 +430,136 @@ proc loadPointsG2*( len: int, stream: Stream ) : seq[G2] =
points.add( loadPointG2(stream) )
return points
#===============================================================================
func addG1*(p,q: G1): G1 =
var r, x, y : ProjG1
prj.fromAffine(x, p)
prj.fromAffine(y, q)
prj.sum(r, x, y)
var s : G1
prj.affine(s, r)
return s
func addG2*(p,q: G2): G2 =
var r, x, y : ProjG2
prj.fromAffine(x, p)
prj.fromAffine(y, q)
prj.sum(r, x, y)
var s : G2
prj.affine(s, r)
return s
func negG1*(p: G1): G1 =
var r : G1 = p
neg(r)
return r
func negG2*(p: G2): G2 =
var r : G2 = p
neg(r)
return r
func `+`*(p,q: G1): G1 = addG1(p,q)
func `+`*(p,q: G2): G2 = addG2(p,q)
func `+=`*(p: var G1, q: G1) = p = addG1(p,q)
func `+=`*(p: var G2, q: G2) = p = addG2(p,q)
func `-=`*(p: var G1, q: G1) = p = addG1(p,negG1(q))
func `-=`*(p: var G2, q: G2) = p = addG2(p,negG2(q))
#-------------------------------------------------------------------------------
# func msmG1( coeffs: seq[Fr] , points: seq[G1] ): G1 =
#
# let N = coeffs.len
# assert( N == points.len, "incompatible sequence lengths" )
#
# # var arr1 = toOpenArray(coeffs, 0, N-1)
# # var arr2 = toOpenArray(points, 0, N-1)
#
# var bigcfs : seq[BigInt[254]]
# for x in coeffs:
# bigcfs.add( x.toBig() )
#
# var r : G1
#
# # [Fp,aff.G1]
# msm.multiScalarMul_vartime( r,
# toOpenArray(bigcfs, 0, N-1),
# toOpenArray(points, 0, N-1) )
#
# return r
#-------------------------------------------------------------------------------
#
# (affine) scalar multiplication
#
func `**`*( coeff: Fr , point: G1 ) : G1 =
var q : ProjG1
prj.fromAffine( q , point )
scl.scalarMul( q , coeff.toBig() )
var r : G1
prj.affine( r, q )
return r
func `**`*( coeff: Fr , point: G2 ) : G2 =
var q : ProjG2
prj.fromAffine( q , point )
scl.scalarMul( q , coeff.toBig() )
var r : G2
prj.affine( r, q )
return r
#-------------------------------------------------------------------------------
func msmNaiveG1( coeffs: seq[Fr] , points: seq[G1] ): G1 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var s : ProjG1
s.setInf()
for i in 0..<N:
var t : ProjG1
prj.fromAffine( t, points[i] )
scl.scalarMul( t , coeffs[i].toBig() )
s += t
var r : G1
prj.affine( r, s )
return r
#---------------------------------------
func msmNaiveG2( coeffs: seq[Fr] , points: seq[G2] ): G2 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var s : ProjG2
s.setInf()
for i in 0..<N:
var t : ProjG2
prj.fromAffine( t, points[i] )
scl.scalarMul( t , coeffs[i].toBig() )
s += t
var r : G2
prj.affine( r, s)
return r
#-------------------------------------------------------------------------------
# TODO: proper MSM implementation (couldn't make constantine work at first...)
func msmG1*( coeffs: seq[Fr] , points: seq[G1] ): G1 =
return msmNaiveG1( coeffs, points )
func msmG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 =
return msmNaiveG2( coeffs, points )
#-------------------------------------------------------------------------------

51
domain.nim Normal file
View File

@ -0,0 +1,51 @@
#
# power-of-two sized multiplicative FFT domains in the scalar field
#
import constantine/math/io/io_bigints
import constantine/math/arithmetic except Fp,Fr
import constantine/math/io/io_fields except Fp,Fr
import ./bn128
import ./misc
#-------------------------------------------------------------------------------
type
Domain* = object
domainSize* : int
logDomainSize* : int
domainGen* : Fr
#-------------------------------------------------------------------------------
# the generator of the multiplicative subgroup with size `2^28`
const gen28 : Fr = fromHex( Fr, "0x2a3c09f0a58a7e8500e0a7eb8ef62abc402d111e41112ed49bd61b6e725b19f0" )
func createDomain*(size: int): Domain =
let log2 = ceilingLog2(size)
assert( (1 shl log2) == size , "domain must have a power-of-two size" )
let expo : uint = 1'u shl (28 - log2)
let gen : Fr = smallPowFr(gen28, expo)
let halfSize = size div 2
let a : Fr = smallPowFr(gen, uint(size ))
let b : Fr = smallPowFr(gen, uint(halfSize))
assert( bool(a == oneFr) , "domain generator sanity check /A" )
assert( not bool(b == oneFr) , "domain generator sanity check /B" )
return Domain( domainSize:size, logDomainSize:log2, domainGen:gen )
#-------------------------------------------------------------------------------
func enumerateDomain*(D: Domain): seq[Fr] =
var xs : seq[Fr] = newSeq[Fr](D.domainSize)
var g : Fr = oneFr
for i in 0..<D.domainSize:
xs[i] = g
g *= D.domainGen
return xs
#-------------------------------------------------------------------------------

133
ntt.nim Normal file
View File

@ -0,0 +1,133 @@
#
# Number-theoretic transform
#
import std/sequtils
import constantine/math/arithmetic except Fp,Fr
import constantine/math/io/io_fields
import bn128
import domain
#-------------------------------------------------------------------------------
func forwardNTT_worker( m: int
, srcStride: int
, gen: Fr
, src: seq[Fr] , srcOfs: int
, buf: var seq[Fr] , bufOfs: int
, tgt: var seq[Fr] , tgtOfs: int ) =
case m
of 0:
tgt[tgtOfs] = src[srcOfs]
of 1:
tgt[tgtOfs ] = src[srcOfs] + src[srcOfs+srcStride]
tgt[tgtOfs+1] = src[srcOfs] - src[srcOfs+srcStride]
else:
let N : int = 1 shl m
let halfN : int = 1 shl (m-1)
var gpow : Fr = gen
square(gpow)
forwardNTT_worker( m-1
, srcStride shl 1
, gpow
, src , srcOfs
, buf , bufOfs + N
, buf , bufOfs )
forwardNTT_worker( m-1
, srcStride shl 1
, gpow
, src , srcOfs + srcStride
, buf , bufOfs + N
, buf , bufOfs + halfN )
gpow = oneFr
for j in 0..<halfN:
let y : Fr = gpow * buf[bufOfs+j+halfN]
tgt[tgtOfs+j ] = buf[bufOfs+j] + y
tgt[tgtOfs+j+halfN] = buf[bufOfs+j] - y
gpow *= gen
#---------------------------------------
# forward number-theoretical transform (corresponds to polynomial evaluation)
func forwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
assert( D.domainSize == (1 shl D.logDomainSize) , "domain must have a power-of-two size" )
assert( D.domainSize == src.len , "input must have the same size as the domain" )
var buf : seq[Fr] = newSeq[Fr]( 2 * D.domainSize )
var tgt : seq[Fr] = newSeq[Fr]( D.domainSize )
forwardNTT_worker( D.logDomainSize
, 1
, D.domainGen
, src , 0
, buf , 0
, tgt , 0 )
return tgt
#-------------------------------------------------------------------------------
const oneHalfFr* : Fr = fromHex(Fr, "0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f8000001")
func inverseNTT_worker( m: int
, tgtStride: int
, gen: Fr
, src: seq[Fr] , srcOfs: int
, buf: var seq[Fr] , bufOfs: int
, tgt: var seq[Fr] , tgtOfs: int ) =
case m
of 0:
tgt[tgtOfs] = src[srcOfs]
of 1:
tgt[tgtOfs ] = ( src[srcOfs] + src[srcOfs+1] ) * oneHalfFr
tgt[tgtOfs+tgtStride] = ( src[srcOfs] - src[srcOfs+1] ) * oneHalfFr
else:
let N : int = 1 shl m
let halfN : int = 1 shl (m-1)
let ginv : Fr = invFr(gen)
var gpow : Fr = oneHalfFr
for j in 0..<halfN:
buf[bufOfs+j ] = ( src[srcOfs+j] + src[srcOfs+j+halfN] ) * oneHalfFr
buf[bufOfs+j+halfN] = ( src[srcOfs+j] - src[srcOfs+j+halfN] ) * gpow
gpow *= ginv
gpow = gen
square(gpow)
inverseNTT_worker( m-1
, tgtStride shl 1
, gpow
, buf , bufOfs
, buf , bufOfs + N
, tgt , tgtOfs )
inverseNTT_worker( m-1
, tgtStride shl 1
, gpow
, buf , bufOfs + halfN
, buf , bufOfs + N
, tgt , tgtOfs + tgtStride )
#---------------------------------------
# inverse number-theoretical transform (corresponds to polynomial interpolation)
func inverseNTT*(src: seq[Fr], D: Domain): seq[Fr] =
assert( D.domainSize == (1 shl D.logDomainSize) , "domain must have a power-of-two size" )
assert( D.domainSize == src.len , "input must have the same size as the domain" )
var buf : seq[Fr] = newSeq[Fr]( 2 * D.domainSize )
var tgt : seq[Fr] = newSeq[Fr]( D.domainSize )
inverseNTT_worker( D.logDomainSize
, 1
, D.domainGen
, src , 0
, buf , 0
, tgt , 0 )
return tgt
#-------------------------------------------------------------------------------

185
poly.nim Normal file
View File

@ -0,0 +1,185 @@
#
# univariate polynomials over Fr
#
# constantine's implementation is "somewhat lacking", so we have to
# implement these ourselves...
#
# TODO: more efficient implementations (right now I just want something working)
#
import std/sequtils
import std/sugar
import constantine/math/arithmetic except Fp,Fr
import constantine/math/io/io_fields
import bn128
import domain
import ntt
#-------------------------------------------------------------------------------
type
Poly* = object
coeffs* : seq[Fr]
#-------------------------------------------------------------------------------
func polyDegree*(P: Poly) : int =
let xs = P.coeffs ; let n = xs.len
var d : int = n-1
while isZeroFr(xs[d]) and (d >= 0): d -= 1
return d
func polyIsZero*(P: Poly) : bool =
let xs = P.coeffs ; let n = xs.len
var b = true
for i in 0..<n:
if not isZeroFr(xs[i]):
b = false
break
return b
func polyEqual*(P, Q: Poly) : bool =
let xs : seq[Fr] = P.coeffs ; let n = xs.len
let ys : seq[Fr] = Q.coeffs ; let m = ys.len
var b = true
if n >= m:
for i in 0..<m: ( if not isEqualFr(xs[i], ys[i]): ( b = false ; break ) )
for i in m..<n: ( if not isZeroFr( xs[i] ): ( b = false ; break ) )
else:
for i in 0..<n: ( if not isEqualFr(xs[i], ys[i]): ( b = false ; break ) )
for i in n..<m: ( if not isZeroFr( ys[i]): ( b = false ; break ) )
return b
#-------------------------------------------------------------------------------
func polyEvalAt*(P: Poly, x0: Fr): Fr =
let cs = P.coeffs ; let n = cs.len
var y : Fr = zeroFr
var r : Fr = oneFr
if n > 0: y = cs[0]
for i in 1..<n:
r *= x0
y += cs[i] * r
return y
#-------------------------------------------------------------------------------
func polyNeg*(P: Poly) : Poly =
let zs : seq[Fr] = map( P.coeffs , negFr )
return Poly(coeffs: zs)
func polyAdd*(P, Q: Poly) : Poly =
let xs = P.coeffs ; let n = xs.len
let ys = Q.coeffs ; let m = ys.len
var zs : seq[Fr]
if n >= m:
for i in 0..<m: zs.add( xs[i] + ys[i] )
for i in m..<n: zs.add( xs[i] )
else:
for i in 0..<n: zs.add( xs[i] + ys[i] )
for i in n..<m: zs.add( ys[i] )
return Poly(coeffs: zs)
func polySub*(P, Q: Poly) : Poly =
let xs = P.coeffs ; let n = xs.len
let ys = Q.coeffs ; let m = ys.len
var zs : seq[Fr]
if n >= m:
for i in 0..<m: zs.add( xs[i] - ys[i] )
for i in m..<n: zs.add( xs[i] )
else:
for i in 0..<n: zs.add( xs[i] + ys[i] )
for i in n..<m: zs.add( zeroFr - ys[i] )
return Poly(coeffs: zs)
#-------------------------------------------------------------------------------
func polyScale*(s: Fr, P: Poly): Poly =
let zs : seq[Fr] = map( P.coeffs , proc (x: Fr): Fr = s*x )
return Poly(coeffs: zs)
#-------------------------------------------------------------------------------
func polyMulNaive*(P, Q : Poly): Poly =
let xs = P.coeffs ; let n1 = xs.len
let ys = Q.coeffs ; let n2 = ys.len
var zs : seq[Fr] ; let N = n1 + n2 - 1
for k in 0..<N:
# 0 <= i <= min(k , n1-1)
# 0 <= j <= min(k , n2-1)
# k = i + j
# 0 >= i = k - j >= k - min(k , n2-1)
# 0 >= j = k - i >= k - min(k , n1-1)
let A : int = max( 0 , k - min(k , n2-1) )
let B : int = min( k , n1-1 )
zs[k] = zeroFr
for i in A..B:
let j = k-i
zs[k] += xs[i] * ys[j]
return Poly(coeffs: zs)
func polyMul*(P, Q : Poly): Poly =
return polyMulNaive(P, Q)
#-------------------------------------------------------------------------------
func `==`*(P, Q: Poly): bool = return polyEqual(P, Q)
func `+`*(P, Q: Poly): Poly = return polyAdd(P, Q)
func `-`*(P, Q: Poly): Poly = return polySub(P, Q)
func `*`*(P, Q: Poly): Poly = return polyMul(P, Q)
func `*`*(s: Fr , P: Poly): Poly = return polyScale(s, P)
func `*`*(P: Poly, s: Fr ): Poly = return polyScale(s, P)
#-------------------------------------------------------------------------------
# evaluates a polynomial on an FFT domain
func polyForwardNTT*(P: Poly, D: Domain): seq[Fr] =
let n = P.coeffs.len
assert( n <= D.domainSize , "the domain must be as least as big as the polynomial" )
if n == D.domainSize:
let src : seq[Fr] = P.coeffs
return forwardNTT(src, D)
else:
var src : seq[Fr] = P.coeffs
for i in n..<D.domainSize: src.add( zeroFr )
return forwardNTT(src, D)
#---------------------------------------
# interpolates a polynomial on an FFT domain
func polyInverseNTT*(ys: seq[Fr], D: Domain): Poly =
let n = ys.len
assert( n == D.domainSize , "the domain must be same size as the input" )
let tgt = inverseNTT(ys, D)
return Poly(coeffs: tgt)
#-------------------------------------------------------------------------------
proc sanityCheckOneHalf*() =
let two = oneFr + oneFr
let invTwo = oneHalfFr
echo(toDecimalFr(two))
echo(toDecimalFr(invTwo * two))
echo(toHex(invTwo))
proc sanityCheckPolys*() =
var js : seq[int] = toSeq(101..108)
let cs : seq[Fr] = map( js, intToFr )
let P : Poly = Poly( coeffs:cs )
let D : Domain = createDomain(8)
let xs : seq[Fr] = D.enumerateDomain()
let ys : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(P,x)) )
let zs : seq[Fr] = polyForwardNTT(P ,D)
let Q : Poly = polyInverseNTT(zs,D)
debugPrintSeqFr("xs", xs)
debugPrintSeqFr("ys", ys)
debugPrintSeqFr("zs", zs)
debugPrintSeqFr("us", Q.coeffs)
#-------------------------------------------------------------------------------