diff --git a/bn128.nim b/bn128.nim index b5c9bc0..d36a5fe 100644 --- a/bn128.nim +++ b/bn128.nim @@ -9,789 +9,23 @@ # equation: y^2 = x^3 + 3 # -import sugar - -import std/bitops -import std/strutils -import std/sequtils -import std/streams -import std/random - -import constantine/platforms/abstractions -import constantine/math/isogenies/frobenius - -import constantine/math/arithmetic -import constantine/math/io/io_fields -import constantine/math/io/io_bigints -import constantine/math/config/curves -import constantine/math/config/type_ff as tff - -import constantine/math/extension_fields/towers as ext -import constantine/math/elliptic/ec_shortweierstrass_affine as aff -import constantine/math/elliptic/ec_shortweierstrass_projective as prj -import constantine/math/pairings/pairings_bn as ate -import constantine/math/elliptic/ec_scalar_mul as scl -import constantine/math/elliptic/ec_multi_scalar_mul as msm - #------------------------------------------------------------------------------- -type B* = BigInt[256] -type Fr* = tff.Fr[BN254Snarks] -type Fp* = tff.Fp[BN254Snarks] - -type Fp2* = ext.QuadraticExt[Fp] -type Fp12* = ext.Fp12[BN254Snarks] - -type G1* = aff.ECP_ShortW_Aff[Fp , aff.G1] -type G2* = aff.ECP_ShortW_Aff[Fp2, aff.G2] - -type ProjG1* = prj.ECP_ShortW_Prj[Fp , prj.G1] -type ProjG2* = prj.ECP_ShortW_Prj[Fp2, prj.G2] - -func mkFp2* (i: Fp, u: Fp) : Fp2 = - let c : array[2, Fp] = [i,u] - return ext.QuadraticExt[Fp]( coords: c ) - -func unsafeMkG1* ( X, Y: Fp ) : G1 = - return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y) - -func unsafeMkG2* ( X, Y: Fp2 ) : G2 = - return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y) - -#------------------------------------------------------------------------------- - -func pairing* (p: G1, q: G2) : Fp12 = - var t : Fp12 - pairing_bn( t, p, q ) - return t - -#------------------------------------------------------------------------------- - -const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) -const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) - -const primeP_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) -const primeR_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) - -#------------------------------------------------------------------------------- - -const zeroFp* : Fp = fromHex( Fp, "0x00" ) -const zeroFr* : Fr = fromHex( Fr, "0x00" ) -const oneFp* : Fp = fromHex( Fp, "0x01" ) -const oneFr* : Fr = fromHex( Fr, "0x01" ) - -const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp ) -const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp ) - -const infG1* : G1 = unsafeMkG1( zeroFp , zeroFp ) -const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 ) - -#------------------------------------------------------------------------------- - -func intToB*(a: uint): B = - var y : B - y.setUint(a) - return y - -func intToFp*(a: int): Fp = - var y : Fp - y.fromInt(a) - return y - -func intToFr*(a: int): Fr = - var y : Fr - y.fromInt(a) - return y - -#------------------------------------------------------------------------------- - -func isZeroFp*(x: Fp): bool = bool(isZero(x)) -func isZeroFr*(x: Fr): bool = bool(isZero(x)) - -func isEqualFp*(x, y: Fp): bool = bool(x == y) -func isEqualFr*(x, y: Fr): bool = bool(x == y) - -func `===`*(x, y: Fp): bool = isEqualFp(x,y) -func `===`*(x, y: Fr): bool = isEqualFr(x,y) +import ./bn128/fields +import ./bn128/curves +import ./bn128/msm +import ./bn128/io +import ./bn128/rnd +import ./bn128/debug #------------------- -func isEqualFpSeq*(xs, ys: seq[Fp]): bool = - let n = xs.len - assert( n == ys.len ) - var b = true - for i in 0.. 0): - if bitand(e,1) > 0: a *= s - e = (e shr 1) - square(s) - return a - -func smallPowFr*(base: Fr, expo: int): Fr = - if expo >= 0: - return smallPowFr( base, uint(expo) ) - else: - return smallPowFr( invFr(base) , uint(-expo) ) - -#------------------------------------------------------------------------------- - -func deltaFr*(i, j: int) : Fr = - return (if (i == j): oneFr else: zeroFr) - -#------------------------------------------------------------------------------- - -func toDecimalBig*[n](a : BigInt[n]): string = - var s : string = toDecimal(a) - s = s.strip( leading=true, trailing=false, chars={'0'} ) - if s.len == 0: s="0" - return s - -func toDecimalFp*(a : Fp): string = - var s : string = toDecimal(a) - s = s.strip( leading=true, trailing=false, chars={'0'} ) - if s.len == 0: s="0" - return s - -func toDecimalFr*(a : Fr): string = - var s : string = toDecimal(a) - s = s.strip( leading=true, trailing=false, chars={'0'} ) - if s.len == 0: s="0" - return s - -#--------------------------------------- - -const k65536 : BigInt[254] = fromHex( BigInt[254], "0x10000", bigEndian ) - -func signedToDecimalFp*(a : Fp): string = - if bool( a.toBig() > primeP_254 - k65536 ): - return "-" & toDecimalFp(negFp(a)) - else: - return toDecimalFp(a) - -func signedToDecimalFr*(a : Fr): string = - if bool( a.toBig() > primeR_254 - k65536 ): - return "-" & toDecimalFr(negFr(a)) - else: - return toDecimalFr(a) - -#------------------------------------------------------------------------------- - -proc debugPrintFp*(prefix: string, x: Fp) = - echo(prefix & toDecimalFp(x)) - -proc debugPrintFp2*(prefix: string, z: Fp2) = - echo(prefix & " 1 ~> " & toDecimalFp(z.coords[0])) - echo(prefix & " u ~> " & toDecimalFp(z.coords[1])) - -proc debugPrintFr*(prefix: string, x: Fr) = - echo(prefix & toDecimalFr(x)) - -proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) = - echo "---------------------" - echo msg - for x in xs: - debugPrintFr( " " , x ) - -proc debugPrintG1*(msg: string, pt: G1) = - echo(msg & ":") - debugPrintFp( " x = ", pt.x ) - debugPrintFp( " y = ", pt.y ) - -proc debugPrintG2*(msg: string, pt: G2) = - echo(msg & ":") - debugPrintFp2( " x = ", pt.x ) - debugPrintFp2( " y = ", pt.y ) - -#------------------------------------------------------------------------------- - -# Montgomery batch inversion -func batchInverse*( xs: seq[Fr] ) : seq[Fr] = - let n = xs.len - assert(n>0) - var us : seq[Fr] = newSeq[Fr](n+1) - var a = xs[0] - us[0] = oneFr - us[1] = a - for i in 1.. " & toDecimalFp(z.coords[0])) + echo(prefix & " u ~> " & toDecimalFp(z.coords[1])) + +proc debugPrintFr*(prefix: string, x: Fr) = + echo(prefix & toDecimalFr(x)) + +proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) = + echo "---------------------" + echo msg + for x in xs: + debugPrintFr( " " , x ) + +proc debugPrintG1*(msg: string, pt: G1) = + echo(msg & ":") + debugPrintFp( " x = ", pt.x ) + debugPrintFp( " y = ", pt.y ) + +proc debugPrintG2*(msg: string, pt: G2) = + echo(msg & ":") + debugPrintFp2( " x = ", pt.x ) + debugPrintFp2( " y = ", pt.y ) + +#------------------------------------------------------------------------------- + diff --git a/bn128/fields.nim b/bn128/fields.nim new file mode 100644 index 0000000..84d4805 --- /dev/null +++ b/bn128/fields.nim @@ -0,0 +1,186 @@ + +# +# the prime fields Fp and Fr with sizes +# +# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +# + +import sugar + +import std/bitops +import std/sequtils + +import constantine/math/arithmetic +import constantine/math/io/io_fields +import constantine/math/io/io_bigints +import constantine/math/config/curves +import constantine/math/config/type_ff as tff +import constantine/math/extension_fields/towers as ext + +#------------------------------------------------------------------------------- + +type B* = BigInt[256] +type Fr* = tff.Fr[BN254Snarks] +type Fp* = tff.Fp[BN254Snarks] + +type Fp2* = ext.QuadraticExt[Fp] +type Fp12* = ext.Fp12[BN254Snarks] + +func mkFp2* (i: Fp, u: Fp) : Fp2 = + let c : array[2, Fp] = [i,u] + return ext.QuadraticExt[Fp]( coords: c ) + +#------------------------------------------------------------------------------- + +const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) +const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) + +#------------------------------------------------------------------------------- + +const zeroFp* : Fp = fromHex( Fp, "0x00" ) +const zeroFr* : Fr = fromHex( Fr, "0x00" ) +const oneFp* : Fp = fromHex( Fp, "0x01" ) +const oneFr* : Fr = fromHex( Fr, "0x01" ) + +const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp ) +const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp ) + +#------------------------------------------------------------------------------- + +func intToB*(a: uint): B = + var y : B + y.setUint(a) + return y + +func intToFp*(a: int): Fp = + var y : Fp + y.fromInt(a) + return y + +func intToFr*(a: int): Fr = + var y : Fr + y.fromInt(a) + return y + +#------------------------------------------------------------------------------- + +func isZeroFp* (x: Fp ): bool = bool(isZero(x)) +func isZeroFp2*(x: Fp2): bool = bool(isZero(x)) +func isZeroFr* (x: Fr ): bool = bool(isZero(x)) + +func isEqualFp* (x, y: Fp ): bool = bool(x == y) +func isEqualFp2*(x, y: Fp2): bool = bool(x == y) +func isEqualFr* (x, y: Fr ): bool = bool(x == y) + +func `===`*(x, y: Fp ): bool = isEqualFp(x,y) +func `===`*(x, y: Fp2): bool = isEqualFp2(x,y) +func `===`*(x, y: Fr ): bool = isEqualFr(x,y) + +#------------------- + +func isEqualFpSeq*(xs, ys: seq[Fp]): bool = + let n = xs.len + assert( n == ys.len ) + var b = true + for i in 0.. 0): + if bitand(e,1) > 0: a *= s + e = (e shr 1) + square(s) + return a + +func smallPowFr*(base: Fr, expo: int): Fr = + if expo >= 0: + return smallPowFr( base, uint(expo) ) + else: + return smallPowFr( invFr(base) , uint(-expo) ) + +#------------------------------------------------------------------------------- + +func deltaFr*(i, j: int) : Fr = + return (if (i == j): oneFr else: zeroFr) + +#------------------------------------------------------------------------------- + +# Montgomery batch inversion +func batchInverseFr*( xs: seq[Fr] ) : seq[Fr] = + let n = xs.len + assert(n>0) + var us : seq[Fr] = newSeq[Fr](n+1) + var a = xs[0] + us[0] = oneFr + us[1] = a + for i in 1.. primeP_254 - k65536 ): + return "-" & toDecimalFp(negFp(a)) + else: + return toDecimalFp(a) + +func signedToDecimalFr*(a : Fr): string = + if bool( a.toBig() > primeR_254 - k65536 ): + return "-" & toDecimalFr(negFr(a)) + else: + return toDecimalFr(a) + +#------------------------------------------------------------------------------- +# Dealing with Montgomery representation +# + +# R=2^256; this computes 2^256 mod Fp +func calcFpMontR*() : Fp = + var x : Fp = intToFp(2) + for i in 1..8: + square(x) + return x + +# R=2^256; this computes the inverse of (2^256 mod Fp) +func calcFpInvMontR*() : Fp = + var x : Fp = calcFpMontR() + inv(x) + return x + +# R=2^256; this computes 2^256 mod Fr +func calcFrMontR*() : Fr = + var x : Fr = intToFr(2) + for i in 1..8: + square(x) + return x + +# R=2^256; this computes the inverse of (2^256 mod Fp) +func calcFrInvMontR*() : Fr = + var x : Fr = calcFrMontR() + inv(x) + return x + +# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) +const fpMontR* : Fp = fromHex( Fp, "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" ) +const fpInvMontR* : Fp = fromHex( Fp, "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" ) + +# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) +const frMontR* : Fr = fromHex( Fr, "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" ) +const frInvMontR* : Fr = fromHex( Fr, "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" ) + +proc checkMontgomeryConstants*() = + assert( bool( fpMontR == calcFpMontR() ) ) + assert( bool( frMontR == calcFrMontR() ) ) + assert( bool( fpInvMontR == calcFpInvMontR() ) ) + assert( bool( frInvMontR == calcFrInvMontR() ) ) + echo("OK") + +#--------------------------------------- + +# the binary file `.zkey` used by the `circom` ecosystem uses little-endian +# Montgomery representation. So when we unmarshal with Constantine, it will +# give the wrong result. Calling this function on the result fixes that. +func fromMontgomeryFp*(x : Fp) : Fp = + var y : Fp = x; + y *= fpInvMontR + return y + +func fromMontgomeryFr*(x : Fr) : Fr = + var y : Fr = x; + y *= frInvMontR + return y + +func toMontgomeryFr*(x : Fr) : Fr = + var y : Fr = x; + y *= frMontR + return y + +#------------------------------------------------------------------------------- +# Unmarshalling field elements +# Note: in the `.zkey` coefficients, e apparently DOUBLE Montgomery encoding is used ?!? +# + +func unmarshalFpMont* ( bs: array[32,byte] ) : Fp = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fp + x.fromBig( big ) + return fromMontgomeryFp(x) + +# WTF Jordi, go home you are drunk +func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fr + x.fromBig( big ) + return fromMontgomeryFr(fromMontgomeryFr(x)) + +func unmarshalFrStd* ( bs: array[32,byte] ) : Fr = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fr + x.fromBig( big ) + return x + +func unmarshalFrMont* ( bs: array[32,byte] ) : Fr = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fr + x.fromBig( big ) + return fromMontgomeryFr(x) + +#------------------------------------------------------------------------------- + +func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] = + var vals : seq[Fp] = newSeq[Fp]( len ) + var bytes : array[32,byte] + for i in 0..= 1024, we use 8 threads + let nthreads = max( 1 , min( N div 128 , nthreadsTarget ) ) + + let m = N div nthreads + + var threads : seq[Thread[InputTuple]] = newSeq[Thread[InputTuple]]( nthreads ) + var results : seq[G1] = newSeq[G1]( nthreads ) + + proc myThreadFunc( inp: InputTuple ) {.thread.} = + results[inp.idx] = msmConstantineG1( inp.coeffs, inp.points ) + + for i in 0..