diff --git a/groth16/bn128/curves.nim b/groth16/bn128/curves.nim index f162875..aeafe60 100644 --- a/groth16/bn128/curves.nim +++ b/groth16/bn128/curves.nim @@ -9,6 +9,8 @@ # equation: y^2 = x^3 + 3 # +import std/bitops +import std/options #import constantine/platforms/abstractions #import constantine/math/isogenies/frobenius @@ -25,6 +27,9 @@ import constantine/math/elliptic/ec_shortweierstrass_projective as prj import constantine/math/pairings/pairings_bn as ate import constantine/math/elliptic/ec_scalar_mul_vartime as scl +import constantine/math/arithmetic/finite_fields_square_root as sqrt +import constantine/math/extension_fields/square_root_fp2 as sqrt2 + import groth16/bn128/fields #------------------------------------------------------------------------------- @@ -35,6 +40,15 @@ type G2* = aff.EC_ShortW_Aff[Fp2[BN254_Snarks], aff.G2] type ProjG1* = prj.EC_ShortW_Prj[Fp[BN254_Snarks] , prj.G1] type ProjG2* = prj.EC_ShortW_Prj[Fp2[BN254_Snarks], prj.G2] +#------------------------------------------------------------------------------- +# compressed points (supposedly compatible with arkworks-0.5) + +type ComprG1* = distinct array[32, byte]; +type ComprG2* = distinct array[64, byte]; + +proc `==` *(a, b: ComprG1): bool {.borrow.} +proc `==` *(a, b: ComprG2): bool {.borrow.} + #------------------------------------------------------------------------------- func isEqualG1* (x, y: G1 ): bool = bool(x == y) @@ -64,6 +78,9 @@ func isInfProjG2*(pt : ProjG2): bool = bool(isNeutral(pt)) #------------------------------------------------------------------------------- +# y^2 = x^3 + B where B = 3 +const theCoeffB = fromHex(Fp[BN254_Snarks], "0x0000000000000000000000000000000000000000000000000000000000000003") + func checkCurveEqG1*( x, y: Fp[BN254_Snarks] ) : bool = if bool(isZero(x)) and bool(isZero(y)): # the point at infinity is on the curve by definition @@ -74,7 +91,7 @@ func checkCurveEqG1*( x, y: Fp[BN254_Snarks] ) : bool = var x3 = x2 * x var eq : Fp[BN254_Snarks] eq = x3 - eq += intToFp(3) + eq += theCoeffB eq -= y2 # echo("eq = ",toDecimalFp(eq)) return (bool(isZero(eq))) @@ -182,6 +199,113 @@ func isInSubgroupG1* ( p: G1 ) : bool = func isInSubgroupG2* ( p: G2 ) : bool = return checkSubgroupG2( p.x, p.y ) +#------------------------------------------------------------------------------- + +func unwrapComprG1*( c1: ComprG1 ): array[32,byte] = + return array[32,byte](c1) + +func unwrapComprG2*( c2: ComprG2 ): array[64,byte] = + return array[64,byte](c2) + +func bigInt256_to_254(inp: BigInt[256]): BigInt[254] = + var res: BigInt[254] + res.copyTruncatedFrom(inp) + return res + +const halfPrime256 : BigInt[256] = fromHex( B, "0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b6c3e7ea3", bigEndian ) +const halfPrime254 : BigInt[254] = bigInt256_to_254( halfPrime256 ) +const thePrime254 : BigInt[254] = bigInt256_to_254( primeP ) + +# little-endian encoding of the X coord, with bit 255 set if `Y > P/2` +func compressG1*( pt : G1 ) : ComprG1 = + var xbig : BigInt[254] + var ybig : BigInt[254] + xbig.fromField( pt.x ) + ybig.fromField( pt.y ) + let flag : bool = bool(ybig > halfPrime254) + var buf : array[32,byte] + buf.marshal(xbig, littleEndian) + if (flag): + buf[31] = bitor( buf[31] , 0x80 ); + return ComprG1(buf) + +func uncompressG1*( compr1 : ComprG1 ) : Option[G1] = + var buf : array[32,byte] = unwrapComprG1(compr1) + let flag : bool = (buf[31] >= 0x80) + buf[31] = bitand( buf[31] , 0x7f ) + var xbig : BigInt[254] + xbig.unmarshal(buf, littleEndian) + if bool(xbig >= thePrime254): + return none(G1) + else: + var x : Fp[BN254_Snarks] + var y : Fp[BN254_Snarks] + x.fromBig(xbig) + y = x*x*x + theCoeffB + let ok = bool( sqrt.sqrt_if_square_vartime(y) ) + if ok: + var ybig : BigInt[254] + ybig.fromField( y ) + let switch = bool(ybig > halfPrime254) xor flag + if switch: + y.neg() + let g1 = unsafeMkG1(x,y) + return some(g1) + else: + return none(G1) + +#--------------------------------------- + +# little-endian encoding of the X coord, with bit 255 set if `Y_imag > P/2` +func compressG2*( pt : G2 ) : ComprG2 = + var x_real_big : BigInt[254] + var x_imag_big : BigInt[254] + var y_imag_big : BigInt[254] + x_real_big.fromField( pt.x.coords[0] ) + x_imag_big.fromField( pt.x.coords[1] ) + y_imag_big.fromField( pt.y.coords[1] ) + let flag : bool = bool(y_imag_big > halfPrime254) + var buf_real : array[32,byte] + var buf_imag : array[32,byte] + marshal(buf_real , x_real_big , littleEndian) + marshal(buf_imag , x_imag_big , littleEndian) + var buf: array[64,byte] + buf[ 0..31] = buf_real + buf[32..63] = buf_imag + if (flag): + buf[63] = bitor( buf[63] , 0x80 ); + return ComprG2(buf) + +func uncompressG2*( compr2 : ComprG2 ) : Option[G2] = + var buf : array[64,byte] = unwrapComprG2(compr2) + let flag : bool = (buf[63] >= 0x80) + buf[63] = bitand( buf[63] , 0x7f ) + var x_big_real : BigInt[254] + var x_big_imag : BigInt[254] + unmarshal(x_big_real , buf , littleEndian) + unmarshal(x_big_imag , buf[32..63] , littleEndian) + if bool(x_big_real >= thePrime254) or bool(x_big_imag >= thePrime254): + return none(G2) + else: + var x_real : Fp[BN254_Snarks] + var x_imag : Fp[BN254_Snarks] + var y : Fp2[BN254_Snarks] + x_real.fromBig(x_big_real) + x_imag.fromBig(x_big_imag) + let x: Fp2[BN254_Snarks] = mkFp2( x_real, x_imag ) + y = x*x*x + twistCoeffB + let ok = bool( sqrt2.sqrt_if_square(y) ) + if ok: + var y_big_imag : BigInt[254] + y_big_imag.fromField( y.coords[1] ) + let switch = bool(y_big_imag > halfPrime254) xor flag + if switch: + y.neg() + let g2 = unsafeMkG2(x,y) + return some(g2) + else: + return none(G2) + #=============================================================================== func addG1*(p,q: G1): G1 = diff --git a/groth16/bn128/fields.nim b/groth16/bn128/fields.nim index 97f1c50..abd52be 100644 --- a/groth16/bn128/fields.nim +++ b/groth16/bn128/fields.nim @@ -27,8 +27,9 @@ func mkFp2* (i: Fp[BN254_Snarks], u: Fp[BN254_Snarks]) : Fp2[BN254_Snarks] = #------------------------------------------------------------------------------- -const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) -const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) +const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) +const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) +const halfPrimeP* : B = fromHex( B, "0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b6c3e7ea3", bigEndian ) #------------------------------------------------------------------------------- diff --git a/groth16/bn128/rnd.nim b/groth16/bn128/rnd.nim index f2f3607..fec88c4 100644 --- a/groth16/bn128/rnd.nim +++ b/groth16/bn128/rnd.nim @@ -9,6 +9,7 @@ import constantine/math/io/io_bigints import constantine/named/properties_fields import groth16/bn128/fields +import groth16/bn128/curves #------------------------------------------------------------------------------- # random values @@ -66,6 +67,18 @@ proc randFr*(): Fr[BN254_Snarks] = y.fromBig( b ) return y +#------------------------------------------------------------------------------- + +proc randG1*(): G1 = + let expo : BigInt[254] = randBig[254]() + return (expo ** gen1) + +proc randG2*(): G2 = + let expo : BigInt[254] = randBig[254]() + return (expo ** gen2) + +#------------------------------------------------------------------------------- + proc testRandom*() = for i in 1..20: let x = randFr() diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 0000000..30d74d2 --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +test \ No newline at end of file diff --git a/tests/groth16/testCurve.nim b/tests/groth16/testCurve.nim index 9b2f489..9c1667d 100644 --- a/tests/groth16/testCurve.nim +++ b/tests/groth16/testCurve.nim @@ -29,7 +29,7 @@ const pt2_yu = fromHex(Fp[BN254_Snarks], "0x04f21f9d99cc25f694cf22ff70dc0ac4692 const pt2_x = mkFp2( pt2_x1, pt2_xu ) const pt2_y = mkFp2( pt2_y1, pt2_yu ) -suite "curves": +suite "curve and subgroup checks": test "gen1 is on the curve": check checkCurveEqG1(gen1.x,gen1.y) diff --git a/tests/groth16/testPtCompression.nim b/tests/groth16/testPtCompression.nim new file mode 100644 index 0000000..9fad92f --- /dev/null +++ b/tests/groth16/testPtCompression.nim @@ -0,0 +1,137 @@ + +{.used.} + +import std/unittest +import std/options +import std/strutils + +import constantine/math/io/io_fields +import constantine/named/properties_fields + +import groth16/bn128/fields +import groth16/bn128/curves +import groth16/bn128/rnd + +import groth16/bn128/debug + +#------------------------------------------------------------------------------- +# compression test cases generated by arkworks + +# last bit is NOT SET +const testA_compressed = "fbc1b30a5acf74d512a3db0ab849bf095c106e5f52cf289c05b58dbb949b2824a156d453ec510b7faad49260faa23781e0962dbb5bbd419753896d42672da717" +const testA_g2_x_real = fromHex(Fp[BN254_Snarks], "0x24289b94bb8db5059c28cf525f6e105c09bf49b80adba312d574cf5a0ab3c1fb") +const testA_g2_x_imag = fromHex(Fp[BN254_Snarks], "0x17a72d67426d89539741bd5bbb2d96e08137a2fa6092d4aa7f0b51ec53d456a1") +const testA_g2_y_real = fromHex(Fp[BN254_Snarks], "0x1c596fe08af669b99b08a7198a94b8abc59e711a5ba88f84b3670aa3da0775f1") +const testA_g2_y_imag = fromHex(Fp[BN254_Snarks], "0x1166e754640ae7db87c1ad56886af9270bed8afd813922628fdd700e36048f09") + +# last bit is SET +const testB_compressed = "979a0fece9f1d92ac5889660f19370145ede8269fbd483714ec0517f76f3c51ced51ff0e98cfb98d94dbade55df493cd57f6af07c60b5e58ce8de13ceac68b9d" +const testB_g2_x_real = fromHex(Fp[BN254_Snarks], "0x1cc5f3767f51c04e7183d4fb6982de5e147093f1609688c52ad9f1e9ec0f9a97") +const testB_g2_x_imag = fromHex(Fp[BN254_Snarks], "0x1d8bc6ea3ce18dce585e0bc607aff657cd93f45de5addb948db9cf980eff51ed") +const testB_g2_y_real = fromHex(Fp[BN254_Snarks], "0x0d034c3de83b9cb8fb066a360afe5391c7e170efc6ebe6d4b93f252126ac204d") +const testB_g2_y_imag = fromHex(Fp[BN254_Snarks], "0x2a7090cf51be2141e049d0176e744fa0420099090b636984bfaa1142456f4b3a") + +func hexStringToByteSeq(hex: string): seq[byte] = + let s = parseHexStr(hex) + result = newSeq[byte](s.len) + for i, ch in s: + result[i] = byte(ch) + +func hexStringToComprG2(hex: string): ComprG2 = + let bseq = hexStringToByteSeq(hex) + var arr: array[64,byte] + for i, b in bseq: + arr[i] = b + return ComprG2(arr) + +#------------------------------------------------------------------------------- + +func mbEqualsG1(mb: Option[G1] , refVal: G1): bool = + var ok: bool = false + if isSome(mb): + let re : G1 = mb.unsafeGet() + ok = (refVal === re) + return ok + +func mbEqualsG2(mb: Option[G2] , refVal: G2): bool = + var ok: bool = false + if isSome(mb): + let re : G2 = mb.unsafeGet() + ok = (refVal === re) + return ok + +#------------------------------------------------------------------------------- + +suite "point compression": + + test "unit test for G2 point compression, test case `A` (flag is not set)": + + let x = mkFp2( testA_g2_x_real , testA_g2_x_imag ) + let y = mkFp2( testA_g2_y_real , testA_g2_y_imag ) + let pt = mkG2( x , y ) + let c = hexStringToComprG2(testA_compressed) + + let ok1 = (compressG2(pt) == c) + let ok2 = mbEqualsG2( uncompressG2(c) , pt ) + + check (ok1 and ok2) + + #--------------------------- + + test "unit test for G2 point compression, test case `B` (flag is set)": + + let x = mkFp2( testB_g2_x_real , testB_g2_x_imag ) + let y = mkFp2( testB_g2_y_real , testB_g2_y_imag ) + let pt = mkG2( x , y ) + let c = hexStringToComprG2(testB_compressed) + + let ok1 = (compressG2(pt) == c) + let ok2 = mbEqualsG2( uncompressG2(c) , pt ) + + check (ok1 and ok2) + + #--------------------------- + + test "test G1 point compression and decompression for 500 random points": + + let N = 500 + var ok = true + var cnt = 0 + + for i in 1..N: + let pt : G1 = randG1() + let c : ComprG1 = compressG1(pt) + let mb : Option[G1] = uncompressG1(c) + + var this_ok = mbEqualsG1(mb, pt) + + ok = ok and this_ok + if this_ok: + cnt += 1 + + echo "out of " & $N & " random tests, " & $cnt & " passed" + check ok + + #--------------------------- + + test "test G2 point compression and decompression for 500 random points": + + let N = 500 + var ok = true + var cnt = 0 + + for i in 1..N: + let pt : G2 = randG2() + let c : ComprG2 = compressG2(pt) + let mb : Option[G2] = uncompressG2(c) + + var this_ok = mbEqualsG2(mb, pt) + + ok = ok and this_ok + if this_ok: + cnt += 1 + + echo "out of " & $N & " random tests, " & $cnt & " passed" + check ok + +#------------------------------------------------------------------------------- diff --git a/tests/test.nim b/tests/test.nim index a4e5e02..36aea78 100644 --- a/tests/test.nim +++ b/tests/test.nim @@ -1,4 +1,5 @@ +import ./groth16/testPtCompression import ./groth16/testCurve import ./groth16/testProver