diff --git a/bn128.nim b/bn128.nim index 8465578..be43806 100644 --- a/bn128.nim +++ b/bn128.nim @@ -1,6 +1,8 @@ +import std/bitops import std/strutils import std/streams +import std/random import constantine/math/arithmetic import constantine/math/io/io_fields @@ -8,27 +10,44 @@ import constantine/math/io/io_bigints import constantine/math/config/curves import constantine/math/config/type_ff as tff -import constantine/math/extension_fields/towers as ext -import constantine/math/elliptic/ec_shortweierstrass_affine as ell +import constantine/math/extension_fields/towers as ext +import constantine/math/elliptic/ec_shortweierstrass_affine as aff +import constantine/math/elliptic/ec_shortweierstrass_projective as prj +import constantine/math/pairings/pairings_bn as ate +import constantine/math/elliptic/ec_multi_scalar_mul as msm +import constantine/math/elliptic/ec_scalar_mul as scl #------------------------------------------------------------------------------- -type B* = BigInt[256] -type Fr* = tff.Fr[BN254Snarks] -type Fp* = tff.Fp[BN254Snarks] -type Fp2* = ext.QuadraticExt[Fp] -type G1* = ell.ECP_ShortW_Aff[Fp , ell.G1] -type G2* = ell.ECP_ShortW_Aff[Fp2, ell.G2] +type B* = BigInt[256] +type Fr* = tff.Fr[BN254Snarks] +type Fp* = tff.Fp[BN254Snarks] -func mkFp2(i: Fp, u: Fp) : Fp2 = +type Fp2* = ext.QuadraticExt[Fp] +type Fp12* = ext.Fp12[BN254Snarks] + +type G1* = aff.ECP_ShortW_Aff[Fp , aff.G1] +type G2* = aff.ECP_ShortW_Aff[Fp2, aff.G2] + +type ProjG1* = prj.ECP_ShortW_Prj[Fp , prj.G1] +type ProjG2* = prj.ECP_ShortW_Prj[Fp2, prj.G2] + +func mkFp2* (i: Fp, u: Fp) : Fp2 = let c : array[2, Fp] = [i,u] return ext.QuadraticExt[Fp]( coords: c ) -func unsafeMkG1( X, Y: Fp ) : G1 = - return ell.ECP_ShortW_Aff[Fp, ell.G1](x: X, y: Y) +func unsafeMkG1* ( X, Y: Fp ) : G1 = + return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y) -func unsafeMkG2( X, Y: Fp2 ) : G2 = - return ell.ECP_ShortW_Aff[Fp2, ell.G2](x: X, y: Y) +func unsafeMkG2* ( X, Y: Fp2 ) : G2 = + return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y) + +#------------------------------------------------------------------------------- + +func pairing* (p: G1, q: G2) : Fp12 = + var t : Fp12 + pairing_bn( t, p, q ) + return t #------------------------------------------------------------------------------- @@ -37,29 +56,70 @@ const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b97 #------------------------------------------------------------------------------- -const zeroFp* : Fp = fromHex( Fp, "0x00" ) -const zeroFr* : Fr = fromHex( Fr, "0x00" ) -const oneFp* : Fp = fromHex( Fp, "0x01" ) -const oneFr* : Fr = fromHex( Fr, "0x01" ) +const zeroFp* : Fp = fromHex( Fp, "0x00" ) +const zeroFr* : Fr = fromHex( Fr, "0x00" ) +const oneFp* : Fp = fromHex( Fp, "0x01" ) +const oneFr* : Fr = fromHex( Fr, "0x01" ) const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp ) +const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp ) + const infG1* : G1 = unsafeMkG1( zeroFp , zeroFp ) const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 ) #------------------------------------------------------------------------------- -func intToFp*(a: int) : Fp = +func intToFp*(a: int): Fp = var y : Fp y.fromInt(a) return y -func intToFr*(a: int) : Fr = +func intToFr*(a: int): Fr = var y : Fr y.fromInt(a) return y #------------------------------------------------------------------------------- +func isZeroFp*(x: Fp): bool = bool(isZero(x)) +func isZeroFr*(x: Fr): bool = bool(isZero(x)) + +func isEquaLFp*(x, y: Fp): bool = bool(x == y) +func isEquaLFr*(x, y: Fr): bool = bool(x == y) + +#------------------------------------------------------------------------------- + +func `+`*(x, y: Fp): Fp = ( var z : Fp = x ; z += y ; return z ) +func `+`*(x, y: Fr): Fr = ( var z : Fr = x ; z += y ; return z ) + +func `-`*(x, y: Fp): Fp = ( var z : Fp = x ; z -= y ; return z ) +func `-`*(x, y: Fr): Fr = ( var z : Fr = x ; z -= y ; return z ) + +func `*`*(x, y: Fp): Fp = ( var z : Fp = x ; z *= y ; return z ) +func `*`*(x, y: Fr): Fr = ( var z : Fr = x ; z *= y ; return z ) + +func negFp*(y: Fp): Fp = ( var z : Fp = zeroFp ; z -= y ; return z ) +func negFr*(y: Fr): Fr = ( var z : Fr = zeroFr ; z -= y ; return z ) + +func invFp*(y: Fp): Fp = ( var z : Fp = y ; inv(z) ; return z ) +func invFr*(y: Fr): Fr = ( var z : Fr = y ; inv(z) ; return z ) + +# template/generic instantiation of `pow_vartime` from here +# /Users/bkomuves/.nimble/pkgs/constantine-0.0.1/constantine/math/arithmetic/finite_fields.nim(389, 7) template/generic instantiation of `fieldMod` from here +# /Users/bkomuves/.nimble/pkgs/constantine-0.0.1/constantine/math/config/curves_prop_field_derived.nim(67, 5) Error: undeclared identifier: 'getCurveOrder' +# ... +func smallPowFr*(base: Fr, expo: uint): Fr = + var a : Fr = oneFr + var s : Fr = base + var e : uint = expo + while (e > 0): + if bitand(e,1) > 0: a *= s + e = (e shr 1) + square(s) + return a + +#------------------------------------------------------------------------------- + func toDecimalBig*[n](a : BigInt[n]): string = var s : string = toDecimal(a) s = s.strip( leading=true, trailing=false, chars={'0'} ) @@ -80,16 +140,93 @@ func toDecimalFr*(a : Fr): string = #------------------------------------------------------------------------------- +proc debugPrintSeqFr*(msg: string, xs: seq[Fr]) = + echo "---------------------" + echo msg + for x in xs: + echo(" " & toDecimalFr(x)) + +#------------------------------------------------------------------------------- +# random values + +var randomInitialized : bool = false +var randomState : Rand = initRand( 12345 ) + +proc rndUint64() : uint64 = + return randomState.next() + +proc initializeRandomIfNecessary() = + if not randomInitialized: + randomState = initRand() + randomInitialized = true + +#----------------------------| 01234567890abcdf01234567890abcdf01234567890abcdf01234567890abcdf +const m64 : B = fromHex( B, "0x0000000000000000000000000000000000000000000000010000000000000000", bigEndian ) +const m128 : B = fromHex( B, "0x0000000000000000000000000000000100000000000000000000000000000000", bigEndian ) +const m192 : B = fromHex( B, "0x0000000000000001000000000000000000000000000000000000000000000000", bigEndian ) +#----------------------------| 01234567890abcdf01234567890abcdf01234567890abcdf01234567890abcdf + +proc randBig*[bits: static int](): BigInt[bits] = + + initializeRandomIfNecessary() + + let a0 : uint64 = rndUint64() + let a1 : uint64 = rndUint64() + let a2 : uint64 = rndUint64() + let a3 : uint64 = rndUint64() + + # echo((a0,a1,a2,a3)) + + var b0 : BigInt[bits] ; b0.fromUint(a0) + var b1 : BigInt[bits] ; b1.fromUint(a1) + var b2 : BigInt[bits] ; b2.fromUint(a2) + var b3 : BigInt[bits] ; b3.fromUint(a3) + + # constantine doesn't appear to have left shift.... + var c1,c2,c3 : BigInt[bits] + prod( c1 , b1 , m64 ) + prod( c2 , b2 , m128 ) + prod( c3 , b3 , m192 ) + + var d : BigInt[bits] + d = b0 + d += c1 + d += c2 + d += c3 + + return d + +proc randFr*(): Fr = + let b : BigInt[254] = randBig[254]() + var y : Fr + y.fromBig( b ) + return y + +proc testRandom*() = + for i in 1..20: + let x = randFr() + echo(x.toHex()) + echo("-------------------") + echo(primeR.toHex()) + +#------------------------------------------------------------------------------- + func checkCurveEqG1*( x, y: Fp ) : bool = - var x2 : Fp = x ; square(x2); - var y2 : Fp = y ; square(y2); - var x3 : Fp = x2 ; x3 *= x; - var eq : Fp - eq = x3 - eq += intToFp(3) - eq -= y2 - # echo("eq = ",toDecimalFp(eq)) - return (bool(isZero(eq))) + if bool(isZero(x)) and bool(isZero(y)): + # the point at infinity is on the curve by definition + return true + else: + var x2 : Fp = x ; square(x2); + var y2 : Fp = y ; square(y2); + var x3 : Fp = x2 ; x3 *= x; + var eq : Fp + eq = x3 + eq += intToFp(3) + eq -= y2 + # echo("eq = ",toDecimalFp(eq)) + return (bool(isZero(eq))) + +#--------------------------------------- # y^2 = x^3 + B # B = b1 + bu*u @@ -100,14 +237,18 @@ const twistCoeffB_u : Fp = fromHex(Fp, "0x009713b03af0fed4cd2cafadeed8fdf4a74fa const twistCoeffB : Fp2 = mkFp2( twistCoeffB_1 , twistCoeffB_u ) func checkCurveEqG2*( x, y: Fp2 ) : bool = - var x2 : Fp2 = x ; square(x2); - var y2 : Fp2 = y ; square(y2); - var x3 : Fp2 = x2 ; x3 *= x; - var eq : Fp2 - eq = x3 - eq += twistCoeffB - eq -= y2 - return (bool(isZero(eq))) + if bool(isZero(x)) and bool(isZero(y)): + # the point at infinity is on the curve by definition + return true + else: + var x2 : Fp2 = x ; square(x2); + var y2 : Fp2 = y ; square(y2); + var x3 : Fp2 = x2 ; x3 *= x; + var eq : Fp2 + eq = x3 + eq += twistCoeffB + eq -= y2 + return (bool(isZero(eq))) #------------------------------------------------------------------------------- @@ -125,6 +266,14 @@ func mkG2( x, y: Fp2 ) : G2 = assert( checkCurveEqG2(x,y) , "mkG2: not a G2 curve point" ) return unsafeMkG2(x,y) +#------------------------------------------------------------------------------- + +func isOnCurveG1* ( p: G1 ) : bool = + return checkCurveEqG1( p.x, p.y ) + +func isOnCurveG2* ( p: G2 ) : bool = + return checkCurveEqG2( p.x, p.y ) + #------------------------------------------------------------------------------- # Dealing with Montgomery representation # @@ -281,4 +430,136 @@ proc loadPointsG2*( len: int, stream: Stream ) : seq[G2] = points.add( loadPointG2(stream) ) return points +#=============================================================================== + +func addG1*(p,q: G1): G1 = + var r, x, y : ProjG1 + prj.fromAffine(x, p) + prj.fromAffine(y, q) + prj.sum(r, x, y) + var s : G1 + prj.affine(s, r) + return s + +func addG2*(p,q: G2): G2 = + var r, x, y : ProjG2 + prj.fromAffine(x, p) + prj.fromAffine(y, q) + prj.sum(r, x, y) + var s : G2 + prj.affine(s, r) + return s + +func negG1*(p: G1): G1 = + var r : G1 = p + neg(r) + return r + +func negG2*(p: G2): G2 = + var r : G2 = p + neg(r) + return r + +func `+`*(p,q: G1): G1 = addG1(p,q) +func `+`*(p,q: G2): G2 = addG2(p,q) + +func `+=`*(p: var G1, q: G1) = p = addG1(p,q) +func `+=`*(p: var G2, q: G2) = p = addG2(p,q) + +func `-=`*(p: var G1, q: G1) = p = addG1(p,negG1(q)) +func `-=`*(p: var G2, q: G2) = p = addG2(p,negG2(q)) + +#------------------------------------------------------------------------------- + +# func msmG1( coeffs: seq[Fr] , points: seq[G1] ): G1 = +# +# let N = coeffs.len +# assert( N == points.len, "incompatible sequence lengths" ) +# +# # var arr1 = toOpenArray(coeffs, 0, N-1) +# # var arr2 = toOpenArray(points, 0, N-1) +# +# var bigcfs : seq[BigInt[254]] +# for x in coeffs: +# bigcfs.add( x.toBig() ) +# +# var r : G1 +# +# # [Fp,aff.G1] +# msm.multiScalarMul_vartime( r, +# toOpenArray(bigcfs, 0, N-1), +# toOpenArray(points, 0, N-1) ) +# +# return r + +#------------------------------------------------------------------------------- +# +# (affine) scalar multiplication +# + +func `**`*( coeff: Fr , point: G1 ) : G1 = + var q : ProjG1 + prj.fromAffine( q , point ) + scl.scalarMul( q , coeff.toBig() ) + var r : G1 + prj.affine( r, q ) + return r + +func `**`*( coeff: Fr , point: G2 ) : G2 = + var q : ProjG2 + prj.fromAffine( q , point ) + scl.scalarMul( q , coeff.toBig() ) + var r : G2 + prj.affine( r, q ) + return r + +#------------------------------------------------------------------------------- + +func msmNaiveG1( coeffs: seq[Fr] , points: seq[G1] ): G1 = + let N = coeffs.len + assert( N == points.len, "incompatible sequence lengths" ) + + var s : ProjG1 + s.setInf() + + for i in 0..= 0): d -= 1 + return d + +func polyIsZero*(P: Poly) : bool = + let xs = P.coeffs ; let n = xs.len + var b = true + for i in 0..= m: + for i in 0.. 0: y = cs[0] + for i in 1..= m: + for i in 0..= m: + for i in 0..= i = k - j >= k - min(k , n2-1) + # 0 >= j = k - i >= k - min(k , n1-1) + let A : int = max( 0 , k - min(k , n2-1) ) + let B : int = min( k , n1-1 ) + zs[k] = zeroFr + for i in A..B: + let j = k-i + zs[k] += xs[i] * ys[j] + return Poly(coeffs: zs) + +func polyMul*(P, Q : Poly): Poly = + return polyMulNaive(P, Q) + +#------------------------------------------------------------------------------- + +func `==`*(P, Q: Poly): bool = return polyEqual(P, Q) + +func `+`*(P, Q: Poly): Poly = return polyAdd(P, Q) +func `-`*(P, Q: Poly): Poly = return polySub(P, Q) +func `*`*(P, Q: Poly): Poly = return polyMul(P, Q) + +func `*`*(s: Fr , P: Poly): Poly = return polyScale(s, P) +func `*`*(P: Poly, s: Fr ): Poly = return polyScale(s, P) + +#------------------------------------------------------------------------------- + +# evaluates a polynomial on an FFT domain +func polyForwardNTT*(P: Poly, D: Domain): seq[Fr] = + let n = P.coeffs.len + assert( n <= D.domainSize , "the domain must be as least as big as the polynomial" ) + + if n == D.domainSize: + let src : seq[Fr] = P.coeffs + return forwardNTT(src, D) + else: + var src : seq[Fr] = P.coeffs + for i in n..