optimized the NTT routines

This commit is contained in:
Balazs Komuves 2024-02-29 20:13:50 +01:00
parent 9d743247e9
commit cfd30a045e
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562

View File

@ -16,7 +16,7 @@ import groth16/math/domain
func forwardNTT_worker( m: int
, srcStride: int
, gen: Fr
, gpows: seq[Fr]
, src: seq[Fr] , srcOfs: int
, buf: var seq[Fr] , bufOfs: int
, tgt: var seq[Fr] , tgtOfs: int ) =
@ -32,26 +32,22 @@ func forwardNTT_worker( m: int
else:
let N : int = 1 shl m
let halfN : int = 1 shl (m-1)
var gpow : Fr = gen
square(gpow)
forwardNTT_worker( m-1
, srcStride shl 1
, gpow
, gpows
, src , srcOfs
, buf , bufOfs + N
, buf , bufOfs )
forwardNTT_worker( m-1
, srcStride shl 1
, gpow
, gpows
, src , srcOfs + srcStride
, buf , bufOfs + N
, buf , bufOfs + halfN )
gpow = oneFr
for j in 0..<halfN:
let y : Fr = gpow * buf[bufOfs+j+halfN]
let y : Fr = gpows[j*srcStride] * buf[bufOfs+j+halfN]
tgt[tgtOfs+j ] = buf[bufOfs+j] + y
tgt[tgtOfs+j+halfN] = buf[bufOfs+j] - y
gpow *= gen
#---------------------------------------
@ -61,9 +57,20 @@ func forwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
assert( D.domainSize == src.len , "input must have the same size as the domain" )
var buf : seq[Fr] = newSeq[Fr]( 2 * D.domainSize )
var tgt : seq[Fr] = newSeq[Fr]( D.domainSize )
# Precalculate the powers of the domain generator once, up front:
#   gpows[i] = gen^i   for 0 <= i < N/2,   where gen = D.domainGen.
# The table is handed to the recursive worker, which looks the powers up
# by index instead of recomputing them with repeated multiplications.
let N = D.domainSize
let halfN = N div 2          # only the first half of the powers is stored
var gpows : seq[Fr] = newSeq[Fr]( halfN )
var x : Fr = oneFr           # running power, starts at gen^0
let gen : Fr = D.domainGen
for i in 0..<halfN:
  gpows[i] = x
  x *= gen
forwardNTT_worker( D.logDomainSize
, 1
, D.domainGen
, gpows
, src , 0
, buf , 0
, tgt , 0 )
@ -89,7 +96,7 @@ const oneHalfFr* : Fr = fromHex(Fr, "0x183227397098d014dc2822db40c0ac2e9419f4243
func inverseNTT_worker( m: int
, tgtStride: int
, gen: Fr
, gpows: seq[Fr]
, src: seq[Fr] , srcOfs: int
, buf: var seq[Fr] , bufOfs: int
, tgt: var seq[Fr] , tgtOfs: int ) =
@ -98,35 +105,30 @@ func inverseNTT_worker( m: int
of 0:
tgt[tgtOfs] = src[srcOfs]
# TODO: faster division by 2
of 1:
tgt[tgtOfs ] = ( src[srcOfs] + src[srcOfs+1] ) * oneHalfFr
tgt[tgtOfs+tgtStride] = ( src[srcOfs] - src[srcOfs+1] ) * oneHalfFr
tgt[tgtOfs ] = ( src[srcOfs] + src[srcOfs+1] )
tgt[tgtOfs+tgtStride] = ( src[srcOfs] - src[srcOfs+1] )
div2( tgt[tgtOfs ] )
div2( tgt[tgtOfs+tgtStride] )
else:
let N : int = 1 shl m
let halfN : int = 1 shl (m-1)
let ginv : Fr = invFr(gen)
var gpow : Fr = oneHalfFr
# TODO: precalculate the gpow vector for repeated iNTT-s ?
for j in 0..<halfN:
buf[bufOfs+j ] = ( src[srcOfs+j] + src[srcOfs+j+halfN] ) * oneHalfFr
buf[bufOfs+j+halfN] = ( src[srcOfs+j] - src[srcOfs+j+halfN] ) * gpow
gpow *= ginv
gpow = gen
square(gpow)
buf[bufOfs+j ] = ( src[srcOfs+j] + src[srcOfs+j+halfN] )
buf[bufOfs+j+halfN] = ( src[srcOfs+j] - src[srcOfs+j+halfN] ) * gpows[ j*tgtStride ]
div2( buf[bufOfs+j ] )
inverseNTT_worker( m-1
, tgtStride shl 1
, gpow
, gpows
, buf , bufOfs
, buf , bufOfs + N
, tgt , tgtOfs )
inverseNTT_worker( m-1
, tgtStride shl 1
, gpow
, gpows
, buf , bufOfs + halfN
, buf , bufOfs + N
, tgt , tgtOfs + tgtStride )
@ -139,9 +141,20 @@ func inverseNTT*(src: seq[Fr], D: Domain): seq[Fr] =
assert( D.domainSize == src.len , "input must have the same size as the domain" )
var buf : seq[Fr] = newSeq[Fr]( 2 * D.domainSize )
var tgt : seq[Fr] = newSeq[Fr]( D.domainSize )
# Precalculate the scaled inverse powers of the domain generator once:
#   gpows[i] = oneHalfFr * ginv^i   for 0 <= i < N/2,
# where ginv = invFr(D.domainGen). Starting the running product at
# oneHalfFr folds that constant factor into every table entry.
let N = D.domainSize
let halfN = N div 2          # only the first half of the powers is stored
var gpows : seq[Fr] = newSeq[Fr]( halfN )
var x : Fr = oneHalfFr       # running value, starts at oneHalfFr * ginv^0
let ginv : Fr = invFr( D.domainGen )
for i in 0..<halfN:
  gpows[i] = x
  x *= ginv
inverseNTT_worker( D.logDomainSize
, 1
, D.domainGen
, gpows
, src , 0
, buf , 0
, tgt , 0 )