diff --git a/poseidon2.nim b/poseidon2.nim index 261879b..44b8e58 100644 --- a/poseidon2.nim +++ b/poseidon2.nim @@ -1,55 +1,15 @@ -import - constantine/math/arithmetic, - constantine/math/config/curves +import constantine/math/arithmetic import poseidon2/types -import poseidon2/roundconst +import poseidon2/roundfun import poseidon2/io export toBytes #------------------------------------------------------------------------------- -const zero : F = getZero() - -const externalRoundConst : array[24, F] = arrayFromHex( externalRoundConstStr ) -const internalRoundConst : array[56, F] = arrayFromHex( internalRoundConstStr ) - -#------------------------------------------------------------------------------- - -# inplace sbox, x => x^5 -func sbox(x: var F) : void = - var y = x - square(y) - square(y) - x *= y - -func linearLayer(x, y, z : var F) = - var s = x ; s += y ; s += z - x += s - y += s - z += s - -func internalRound(j: int; x, y, z: var F) = - x += internalRoundConst[j] - sbox(x) - var s = x ; s += y ; s += z - double(z) - x += s - y += s - z += s - -func externalRound(j: int; x, y, z : var F) = - x += externalRoundConst[3*j+0] - y += externalRoundConst[3*j+1] - z += externalRoundConst[3*j+2] - sbox(x) ; sbox(y) ; sbox(z) - var s = x ; s += y ; s += z - x += s - y += s - z += s - -func permInplace*(x, y, z : var F) = +# the Poseidon2 permutation (mutable, in-place version) +proc permInplace*(x, y, z : var F) = linearLayer(x, y, z); for j in 0..3: externalRound(j, x, y, z) @@ -58,6 +18,7 @@ func permInplace*(x, y, z : var F) = for j in 4..7: externalRound(j, x, y, z) +# the Poseidon2 permutation func perm*(xyz: S) : S = var (x,y,z) = xyz permInplace(x, y, z) @@ -65,6 +26,56 @@ func perm*(xyz: S) : S = #------------------------------------------------------------------------------- +# sponge with rate=1 (capacity=2) +func spongeWithRate1*(xs: openArray[F]) : F = + let a = low(xs) + let b = high(xs) + let n = b-a+1 + + var s0 : F = zero + var s1 : F = zero + var s2 : F = zero + + for i in 0.. x^5 +func sbox*(x: var F) : void = + var y = x + square(y) + square(y) + x *= y + +func linearLayer*(x, y, z : var F) = + var s = x ; s += y ; s += z + x += s + y += s + z += s + +func internalRound*(j: int; x, y, z: var F) = + x += internalRoundConst[j] + sbox(x) + var s = x ; s += y ; s += z + double(z) + x += s + y += s + z += s + +func externalRound*(j: int; x, y, z : var F) = + x += externalRoundConst[3*j+0] + y += externalRoundConst[3*j+1] + z += externalRoundConst[3*j+2] + sbox(x) ; sbox(y) ; sbox(z) + var s = x ; s += y ; s += z + x += s + y += s + z += s + +#------------------------------------------------------------------------------- + diff --git a/poseidon2/types.nim b/poseidon2/types.nim index a14d541..093a909 100644 --- a/poseidon2/types.nim +++ b/poseidon2/types.nim @@ -1,8 +1,8 @@ import + constantine/math/arithmetic, constantine/math/io/io_fields, constantine/math/io/io_bigints, - constantine/math/arithmetic, constantine/math/config/curves #------------------------------------------------------------------------------- @@ -18,9 +18,17 @@ func getZero*() : F = setZero(z) return z +func getOne*() : F = + var y : F + # y.fromUint(1'u32) # WTF, why does this not compile ??? + y.fromHex("0x01") + return y + +# for some reason this one does not compile... ??? +# (when actually called) func toF*(a: int) : F = var y : F - fromInt(y, a); + y.fromInt(a) return y func hexToF*(s : string, endian: static Endianness = bigEndian) : F = diff --git a/tests/poseidon2/testPoseidon2.nim b/tests/poseidon2/testPoseidon2.nim index 70e6ea4..f28ef20 100644 --- a/tests/poseidon2/testPoseidon2.nim +++ b/tests/poseidon2/testPoseidon2.nim @@ -10,6 +10,36 @@ import constantine/serialization/codecs import poseidon2/types import poseidon2 +#------------------------------------------------------------------------------- + +const expectedSpongeResultsRate1 : array[8, string] = + [ "12363515589665961836680709257448433057869762330741639517836048636244832188495" + , "10755250120808789043370150604836786069442045362641800439807384337872752972068" + , "04842014531366721455661330916203255410159059117951668762867230544004815370337" + , "13502515636936876459766686836354199651004594178376827739246669803080321705927" + , "19312121576697000598919845239663673946550934099828684806027699882665482322097" + , "21509595983900483103260021285060939918324350560398732346653142062765920502059" + , "11892726572958426459775026381831352388154613015696290329810000571844227402585" + , "10284126944232604349630438079200913190801781418325975675236599364113149409058" + ] + +# TODO: add domain separation between rate=1 and rate=2, so that the empty input +# gives different results. But this has to be done in all the other Poseidon2 libraries +# too (circom, Haskell, C...) + +const expectedSpongeResultsRate2 : array[8, string] = + [ "12363515589665961836680709257448433057869762330741639517836048636244832188495" + , "00899009032366875286186953183805404053380636995610127460025486428583509745414" + , "16500906802543951227422597869354004883060519121579073949799015758201044544012" + , "05275430613748165078459451567241807462288293965310307668712900802458919462965" + , "13763559248248167400098483085605230840597893317332127197498651878933380690961" + , "14871143128308815290845020646262475973102494373985615216162863857354721038367" + , "02746725081632011689597680224823496636241961292066939394613880404914874634920" + , "02290144245981244996669076598332792758523446545263085369617640761875376727694" + ] + +#------------------------------------------------------------------------------- + suite "poseidon2": test "permutation in place": @@ -23,6 +53,24 @@ suite "poseidon2": check toDecimal(y) == "09030699330013392132529464674294378792132780497765201297316864012141442630280" check toDecimal(z) == "09137931384593657624554037900714196568304064431583163402259937475584578975855" + test "sponge with rate=1": + for n in 0..7: + var xs: seq[F] + for i in 1..n: + xs.add( toF(i) ) + let h = spongeWithRate1(xs) + # echo(toDecimal(h)) + check toDecimal(h) == expectedSpongeResultsRate1[n] + + test "sponge with rate=2": + for n in 0..7: + var xs: seq[F] + for i in 1..n: + xs.add( toF(i) ) + let h = spongeWithRate2(xs) + # echo(toDecimal(h)) + check toDecimal(h) == expectedSpongeResultsRate2[n] + test "merkle root of field elements": let m = 17 let n = 2^m