Merge pull request #2 from codex-storage/fix-msm

use constantine msm
This commit is contained in:
Balazs Komuves 2023-11-14 09:37:36 +01:00 committed by GitHub
commit 7379bc04ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

224
bn128.nim
View File

@ -17,10 +17,13 @@ import std/sequtils
import std/streams
import std/random
import constantine/platforms/abstractions
import constantine/math/isogenies/frobenius
import constantine/math/arithmetic
import constantine/math/io/io_fields
import constantine/math/io/io_bigints
import constantine/math/config/curves
import constantine/math/config/curves
import constantine/math/config/type_ff as tff
import constantine/math/extension_fields/towers as ext
@ -28,7 +31,7 @@ import constantine/math/elliptic/ec_shortweierstrass_affine as aff
import constantine/math/elliptic/ec_shortweierstrass_projective as prj
import constantine/math/pairings/pairings_bn as ate
import constantine/math/elliptic/ec_scalar_mul as scl
# import constantine/math/elliptic/ec_multi_scalar_mul as msm
import constantine/math/elliptic/ec_multi_scalar_mul as msm
#-------------------------------------------------------------------------------
@ -49,10 +52,10 @@ func mkFp2* (i: Fp, u: Fp) : Fp2 =
let c : array[2, Fp] = [i,u]
return ext.QuadraticExt[Fp]( coords: c )
func unsafeMkG1* ( X, Y: Fp ) : G1 =
func unsafeMkG1* ( X, Y: Fp ) : G1 =
return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y)
func unsafeMkG2* ( X, Y: Fp2 ) : G2 =
func unsafeMkG2* ( X, Y: Fp2 ) : G2 =
return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y)
#-------------------------------------------------------------------------------
@ -113,25 +116,25 @@ func `===`*(x, y: Fr): bool = isEqualFr(x,y)
#-------------------
func isEqualFpSeq*(xs, ys: seq[Fp]): bool =
func isEqualFpSeq*(xs, ys: seq[Fp]): bool =
let n = xs.len
assert( n == ys.len )
var b = true
for i in 0..<n:
for i in 0..<n:
if not bool(xs[i] == ys[i]):
b = false
break
return b
return b
func isEqualFrSeq*(xs, ys: seq[Fr]): bool =
func isEqualFrSeq*(xs, ys: seq[Fr]): bool =
let n = xs.len
assert( n == ys.len )
var b = true
for i in 0..<n:
for i in 0..<n:
if not bool(xs[i] == ys[i]):
b = false
break
return b
return b
func `===`*(xs, ys: seq[Fp]): bool = isEqualFpSeq(xs,ys)
func `===`*(xs, ys: seq[Fr]): bool = isEqualFrSeq(xs,ys)
@ -142,7 +145,7 @@ func `===`*(xs, ys: seq[Fr]): bool = isEqualFrSeq(xs,ys)
func `+`*[n](x, y: BigInt[n] ): BigInt[n] = ( var z : BigInt[n] = x ; z += y ; return z )
func `+`*(x, y: Fp): Fp = ( var z : Fp = x ; z += y ; return z )
func `+`*(x, y: Fr): Fr = ( var z : Fr = x ; z += y ; return z )
#func `-`*(x, y: B ): B = ( var z : B = x ; z -= y ; return z )
func `-`*[n](x, y: BigInt[n] ): BigInt[n] = ( var z : BigInt[n] = x ; z -= y ; return z )
func `-`*(x, y: Fp): Fp = ( var z : Fp = x ; z -= y ; return z )
@ -161,17 +164,17 @@ func invFr*(y: Fr): Fr = ( var z : Fr = y ; inv(z) ; return z )
# /Users/bkomuves/.nimble/pkgs/constantine-0.0.1/constantine/math/arithmetic/finite_fields.nim(389, 7) template/generic instantiation of `fieldMod` from here
# /Users/bkomuves/.nimble/pkgs/constantine-0.0.1/constantine/math/config/curves_prop_field_derived.nim(67, 5) Error: undeclared identifier: 'getCurveOrder'
# ...
func smallPowFr*(base: Fr, expo: uint): Fr =
func smallPowFr*(base: Fr, expo: uint): Fr =
var a : Fr = oneFr
var s : Fr = base
var e : uint = expo
var e : uint = expo
while (e > 0):
if bitand(e,1) > 0: a *= s
e = (e shr 1)
square(s)
return a
func smallPowFr*(base: Fr, expo: int): Fr =
func smallPowFr*(base: Fr, expo: int): Fr =
if expo >= 0:
return smallPowFr( base, uint(expo) )
else:
@ -233,7 +236,7 @@ proc debugPrintFr*(prefix: string, x: Fr) =
proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) =
echo "---------------------"
echo msg
for x in xs:
for x in xs:
debugPrintFr( " " , x )
proc debugPrintG1*(msg: string, pt: G1) =
@ -249,7 +252,7 @@ proc debugPrintG2*(msg: string, pt: G2) =
#-------------------------------------------------------------------------------
# Montgomery batch inversion
func batchInverse*( xs: seq[Fr] ) : seq[Fr] =
func batchInverse*( xs: seq[Fr] ) : seq[Fr] =
let n = xs.len
assert(n>0)
var us : seq[Fr] = newSeq[Fr](n+1)
@ -269,9 +272,9 @@ proc sanityCheckBatchInverse*() =
let n = xs.len
# for i in 0..<n: echo(i," | batch = ",toDecimalFr(ys[i])," | ref = ",toDecimalFr(zs[i]) )
for i in 0..<n:
if not bool(ys[i] == zs[i]):
if not bool(ys[i] == zs[i]):
echo "batch inverse test FAILED!"
return
return
echo "batch iverse test OK."
#-------------------------------------------------------------------------------
@ -280,7 +283,7 @@ proc sanityCheckBatchInverse*() =
var randomInitialized : bool = false
var randomState : Rand = initRand( 12345 )
proc rndUint64() : uint64 =
proc rndUint64() : uint64 =
return randomState.next()
proc initializeRandomIfNecessary() =
@ -293,15 +296,15 @@ const m64 : B = fromHex( B, "0x000000000000000000000000000000000000000000000001
const m128 : B = fromHex( B, "0x0000000000000000000000000000000100000000000000000000000000000000", bigEndian )
const m192 : B = fromHex( B, "0x0000000000000001000000000000000000000000000000000000000000000000", bigEndian )
#----------------------------| 01234567890abcdf01234567890abcdf01234567890abcdf01234567890abcdf
proc randBig*[bits: static int](): BigInt[bits] =
initializeRandomIfNecessary()
let a0 : uint64 = rndUint64()
let a1 : uint64 = rndUint64()
let a2 : uint64 = rndUint64()
let a3 : uint64 = rndUint64()
let a0 : uint64 = rndUint64()
let a1 : uint64 = rndUint64()
let a2 : uint64 = rndUint64()
let a3 : uint64 = rndUint64()
# echo((a0,a1,a2,a3))
@ -311,16 +314,16 @@ proc randBig*[bits: static int](): BigInt[bits] =
var b3 : BigInt[bits] ; b3.fromUint(a3)
# constantine doesn't appear to have left shift....
var c1,c2,c3 : BigInt[bits]
prod( c1 , b1 , m64 )
var c1,c2,c3 : BigInt[bits]
prod( c1 , b1 , m64 )
prod( c2 , b2 , m128 )
prod( c3 , b3 , m192 )
var d : BigInt[bits]
d = b0
d += c1
d += c2
d += c3
d += c1
d += c2
d += c3
return d
@ -330,7 +333,7 @@ proc randFr*(): Fr =
y.fromBig( b )
return y
proc testRandom*() =
proc testRandom*() =
for i in 1..20:
let x = randFr()
echo(x.toHex())
@ -358,7 +361,7 @@ func checkCurveEqG1*( x, y: Fp ) : bool =
# y^2 = x^3 + B
# B = b1 + bu*u
# b1 = 19485874751759354771024239261021720505790618469301721065564631296452457478373
# b1 = 19485874751759354771024239261021720505790618469301721065564631296452457478373
# b2 = 266929791119991161246907387137283842545076965332900288569378510910307636690
const twistCoeffB_1 : Fp = fromHex(Fp, "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5")
const twistCoeffB_u : Fp = fromHex(Fp, "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2")
@ -380,14 +383,14 @@ func checkCurveEqG2*( x, y: Fp2 ) : bool =
#-------------------------------------------------------------------------------
func mkG1( x, y: Fp ) : G1 =
func mkG1( x, y: Fp ) : G1 =
if bool(isZero(x)) and bool(isZero(y)):
return infG1
else:
assert( checkCurveEqG1(x,y) , "mkG1: not a G1 curve point" )
return unsafeMkG1(x,y)
func mkG2( x, y: Fp2 ) : G2 =
func mkG2( x, y: Fp2 ) : G2 =
if bool(isZero(x)) and bool(isZero(y)):
return infG2
else:
@ -413,10 +416,10 @@ const gen2* : G2 = unsafeMkG2( gen2_x, gen2_y )
#-------------------------------------------------------------------------------
func isOnCurveG1* ( p: G1 ) : bool =
func isOnCurveG1* ( p: G1 ) : bool =
return checkCurveEqG1( p.x, p.y )
func isOnCurveG2* ( p: G2 ) : bool =
func isOnCurveG2* ( p: G2 ) : bool =
return checkCurveEqG2( p.x, p.y )
#-------------------------------------------------------------------------------
@ -424,27 +427,27 @@ func isOnCurveG2* ( p: G2 ) : bool =
#
# R=2^256; this computes 2^256 mod Fp
func calcFpMontR*() : Fp =
func calcFpMontR*() : Fp =
var x : Fp = intToFp(2)
for i in 1..8:
for i in 1..8:
square(x)
return x
# R=2^256; this computes the inverse of (2^256 mod Fp)
func calcFpInvMontR*() : Fp =
func calcFpInvMontR*() : Fp =
var x : Fp = calcFpMontR()
inv(x)
return x
# R=2^256; this computes 2^256 mod Fr
func calcFrMontR*() : Fr =
func calcFrMontR*() : Fr =
var x : Fr = intToFr(2)
for i in 1..8:
for i in 1..8:
square(x)
return x
# R=2^256; this computes the inverse of (2^256 mod Fp)
func calcFrInvMontR*() : Fr =
func calcFrInvMontR*() : Fr =
var x : Fr = calcFrMontR()
inv(x)
return x
@ -467,20 +470,20 @@ proc checkMontgomeryConstants*() =
#---------------------------------------
# the binary files used by the `circom` ecosystem (EXCEPT the witness file!)
# always use little-endian Montgomery representation. So when we unmarshal
# with Constantine, it will give the wrong result. Calling this function on
# always use little-endian Montgomery representation. So when we unmarshal
# with Constantine, it will give the wrong result. Calling this function on
# the result fixes that.
func fromMontgomeryFp*(x : Fp) : Fp =
func fromMontgomeryFp*(x : Fp) : Fp =
var y : Fp = x;
y *= fpInvMontR
return y
func fromMontgomeryFr*(x : Fr) : Fr =
func fromMontgomeryFr*(x : Fr) : Fr =
var y : Fr = x;
y *= frInvMontR
return y
func toMontgomeryFr*(x : Fr) : Fr =
func toMontgomeryFr*(x : Fr) : Fr =
var y : Fr = x;
y *= frMontR
return y
@ -493,37 +496,37 @@ func toMontgomeryFr*(x : Fr) : Fr =
#
# WTF Jordi, go home you are drunk
func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
x.fromBig( big )
return fromMontgomeryFr(fromMontgomeryFr(x))
func unmarshalFrStd* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
func unmarshalFrStd* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
x.fromBig( big )
return x
func unmarshalFpMont* ( bs: array[32,byte] ) : Fp =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
func unmarshalFpMont* ( bs: array[32,byte] ) : Fp =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fp
x.fromBig( big )
return fromMontgomeryFp(x)
func unmarshalFrMont* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
func unmarshalFrMont* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
x.fromBig( big )
return fromMontgomeryFr(x)
#-------------------------------------------------------------------------------
func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] =
func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] =
var vals : seq[Fp] = newSeq[Fp]( len )
var bytes : array[32,byte]
for i in 0..<len:
@ -531,7 +534,7 @@ func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] =
vals[i] = unmarshalFpMont( bytes )
return vals
func unmarshalFrMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fr] =
func unmarshalFrMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fr] =
var vals : seq[Fr] = newSeq[Fr]( len )
var bytes : array[32,byte]
for i in 0..<len:
@ -541,7 +544,7 @@ func unmarshalFrMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fr] =
#-------------------------------------------------------------------------------
proc loadValueFrWTF*( stream: Stream ) : Fr =
proc loadValueFrWTF*( stream: Stream ) : Fr =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
# for i in 0..<32: stdout.write(" " & toHex(bytes[i]))
@ -549,44 +552,44 @@ proc loadValueFrWTF*( stream: Stream ) : Fr =
assert( n == 32 )
return unmarshalFrWTF(bytes)
proc loadValueFrStd*( stream: Stream ) : Fr =
proc loadValueFrStd*( stream: Stream ) : Fr =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFrStd(bytes)
proc loadValueFrMont*( stream: Stream ) : Fr =
proc loadValueFrMont*( stream: Stream ) : Fr =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFrMont(bytes)
proc loadValueFpMont*( stream: Stream ) : Fp =
proc loadValueFpMont*( stream: Stream ) : Fp =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFpMont(bytes)
proc loadValueFp2Mont*( stream: Stream ) : Fp2 =
proc loadValueFp2Mont*( stream: Stream ) : Fp2 =
let i = loadValueFpMont( stream )
let u = loadValueFpMont( stream )
return mkFp2(i,u)
#---------------------------------------
proc loadValuesFrStd*( len: int, stream: Stream ) : seq[Fr] =
proc loadValuesFrStd*( len: int, stream: Stream ) : seq[Fr] =
var values : seq[Fr]
for i in 1..len:
values.add( loadValueFrStd(stream) )
return values
proc loadValuesFpMont*( len: int, stream: Stream ) : seq[Fp] =
proc loadValuesFpMont*( len: int, stream: Stream ) : seq[Fp] =
var values : seq[Fp]
for i in 1..len:
values.add( loadValueFpMont(stream) )
return values
proc loadValuesFrMont*( len: int, stream: Stream ) : seq[Fr] =
proc loadValuesFrMont*( len: int, stream: Stream ) : seq[Fr] =
var values : seq[Fr]
for i in 1..len:
values.add( loadValueFrMont(stream) )
@ -594,25 +597,25 @@ proc loadValuesFrMont*( len: int, stream: Stream ) : seq[Fr] =
#-------------------------------------------------------------------------------
proc loadPointG1*( stream: Stream ) : G1 =
proc loadPointG1*( stream: Stream ) : G1 =
let x = loadValueFpMont( stream )
let y = loadValueFpMont( stream )
return mkG1(x,y)
proc loadPointG2*( stream: Stream ) : G2 =
proc loadPointG2*( stream: Stream ) : G2 =
let x = loadValueFp2Mont( stream )
let y = loadValueFp2Mont( stream )
return mkG2(x,y)
#---------------------------------------
proc loadPointsG1*( len: int, stream: Stream ) : seq[G1] =
proc loadPointsG1*( len: int, stream: Stream ) : seq[G1] =
var points : seq[G1]
for i in 1..len:
points.add( loadPointG1(stream) )
return points
proc loadPointsG2*( len: int, stream: Stream ) : seq[G2] =
proc loadPointsG2*( len: int, stream: Stream ) : seq[G2] =
var points : seq[G2]
for i in 1..len:
points.add( loadPointG2(stream) )
@ -620,30 +623,30 @@ proc loadPointsG2*( len: int, stream: Stream ) : seq[G2] =
#===============================================================================
func addG1*(p,q: G1): G1 =
func addG1*(p,q: G1): G1 =
var r, x, y : ProjG1
prj.fromAffine(x, p)
prj.fromAffine(y, q)
prj.sum(r, x, y)
var s : G1
prj.affine(s, r)
return s
return s
func addG2*(p,q: G2): G2 =
func addG2*(p,q: G2): G2 =
var r, x, y : ProjG2
prj.fromAffine(x, p)
prj.fromAffine(y, q)
prj.sum(r, x, y)
var s : G2
prj.affine(s, r)
return s
return s
func negG1*(p: G1): G1 =
func negG1*(p: G1): G1 =
var r : G1 = p
neg(r)
return r
func negG2*(p: G2): G2 =
func negG2*(p: G2): G2 =
var r : G2 = p
neg(r)
return r
@ -659,8 +662,7 @@ func `-=`*(p: var G2, q: G2) = p = addG2(p,negG2(q))
#-------------------------------------------------------------------------------
#[
func msmG1( coeffs: seq[Fr] , points: seq[G1] ): G1 =
func msmG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
@ -672,30 +674,53 @@ func msmG1( coeffs: seq[Fr] , points: seq[G1] ): G1 =
for x in coeffs:
bigcfs.add( x.toBig() )
var r : G1
var r : ProjG1
# [Fp,aff.G1]
msm.multiScalarMul_vartime( r,
msm.multiScalarMul_vartime( r,
toOpenArray(bigcfs, 0, N-1),
toOpenArray(points, 0, N-1) )
return r
]#
var rAff: G1
prj.affine(rAff, r)
return rAff
func msmG2*( coeffs: openArray[Fr] , points: openArray[G2] ): G2 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var bigcfs : seq[BigInt[254]]
for x in coeffs:
bigcfs.add( x.toBig() )
var r : ProjG2
# [Fp,aff.G1]
msm.multiScalarMul_vartime( r,
toOpenArray(bigcfs, 0, N-1),
toOpenArray(points, 0, N-1) )
var rAff: G2
prj.affine(rAff, r)
return rAff
#-------------------------------------------------------------------------------
#
# (affine) scalar multiplication
#
func `**`*( coeff: Fr , point: G1 ) : G1 =
var q : ProjG1
func `**`*( coeff: Fr , point: G1 ) : G1 =
var q : ProjG1
prj.fromAffine( q , point )
scl.scalarMul( q , coeff.toBig() )
var r : G1
prj.affine( r, q )
return r
func `**`*( coeff: Fr , point: G2 ) : G2 =
func `**`*( coeff: Fr , point: G2 ) : G2 =
var q : ProjG2
prj.fromAffine( q , point )
scl.scalarMul( q , coeff.toBig() )
@ -705,15 +730,15 @@ func `**`*( coeff: Fr , point: G2 ) : G2 =
#-------------------
func `**`*( coeff: BigInt , point: G1 ) : G1 =
var q : ProjG1
func `**`*( coeff: BigInt , point: G1 ) : G1 =
var q : ProjG1
prj.fromAffine( q , point )
scl.scalarMul( q , coeff )
var r : G1
prj.affine( r, q )
return r
func `**`*( coeff: BigInt , point: G2 ) : G2 =
func `**`*( coeff: BigInt , point: G2 ) : G2 =
var q : ProjG2
prj.fromAffine( q , point )
scl.scalarMul( q , coeff )
@ -723,10 +748,10 @@ func `**`*( coeff: BigInt , point: G2 ) : G2 =
#-------------------------------------------------------------------------------
func msmNaiveG1( coeffs: seq[Fr] , points: seq[G1] ): G1 =
func msmNaiveG1( coeffs: seq[Fr] , points: seq[G1] ): G1 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var s : ProjG1
s.setInf()
@ -743,10 +768,10 @@ func msmNaiveG1( coeffs: seq[Fr] , points: seq[G1] ): G1 =
#---------------------------------------
func msmNaiveG2( coeffs: seq[Fr] , points: seq[G2] ): G2 =
func msmNaiveG2( coeffs: seq[Fr] , points: seq[G2] ): G2 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var s : ProjG2
s.setInf()
@ -763,16 +788,7 @@ func msmNaiveG2( coeffs: seq[Fr] , points: seq[G2] ): G2 =
#-------------------------------------------------------------------------------
# TODO: proper MSM implementation (couldn't make constantine work at first...)
func msmG1*( coeffs: seq[Fr] , points: seq[G1] ): G1 =
return msmNaiveG1( coeffs, points )
func msmG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 =
return msmNaiveG2( coeffs, points )
#-------------------------------------------------------------------------------
proc sanityCheckGroupGen*() =
proc sanityCheckGroupGen*() =
echo( "gen1 on the curve = ", checkCurveEqG1(gen1.x,gen1.y) )
echo( "gen2 on the curve = ", checkCurveEqG2(gen2.x,gen2.y) )
echo( "order of gen1 is R = ", (not bool(isInf(gen1))) and bool(isInf(primeR ** gen1)) )