gnark-rln/rln/poseidon.go

141 lines
3.7 KiB
Go

// forked from https://raw.githubusercontent.com/AlpinYukseloglu/poseidon-gnark/main/circuits/poseidon.go
package rln
import (
"math/big"
"github.com/consensys/gnark/frontend"
)
// Sigma is the Poseidon S-box: it returns in^5.
//
// The power is built by repeated squaring (in^2, then in^4, then in^4*in),
// which needs three multiplication constraints instead of the four produced
// by a left fold over five factors. This matches circomlib's Sigma template
// (in2, in4, out).
func Sigma(api frontend.API, in frontend.Variable) frontend.Variable {
	in2 := api.Mul(in, in)
	in4 := api.Mul(in2, in2)
	return api.Mul(in4, in)
}
// Ark adds round constants to the state: in[i] += c[r+i] for every index i.
//
// The update happens in place. The returned slice is the same slice that was
// passed in, so callers may either reassign it or keep using their own
// reference. c must hold at least r+len(in) entries.
func Ark(api frontend.API, in []frontend.Variable, c []*big.Int, r int) []frontend.Variable {
	for idx, v := range in {
		in[idx] = api.Add(v, c[r+idx])
	}
	return in
}
// multiplyAndAdd returns the linear combination sum_i factors[i]*in[i],
// accumulated from the constant 0 (so an empty in yields 0).
// factors must have at least len(in) entries; extra entries are ignored.
func multiplyAndAdd(api frontend.API, in []frontend.Variable, factors []*big.Int) frontend.Variable {
	var acc frontend.Variable = 0
	for idx := 0; idx < len(in); idx++ {
		term := api.Mul(factors[idx], in[idx])
		acc = api.Add(acc, term)
	}
	return acc
}
// createMixFactors returns column `index` of matrix m, limited to the first
// len(in) rows: entry j of the result is m[j][index]. Only the length of in
// is used; its values are never read.
func createMixFactors(in []frontend.Variable, m [][]*big.Int, index int) []*big.Int {
	column := make([]*big.Int, 0, len(in))
	for row := range in {
		column = append(column, m[row][index])
	}
	return column
}
// Mix multiplies the state by matrix m, computing
// out[i] = sum_j m[j][i] * in[j] for every i.
// A new slice is returned; in is left unchanged.
func Mix(api frontend.API, in []frontend.Variable, m [][]*big.Int) []frontend.Variable {
	out := make([]frontend.Variable, len(in))
	for col := range out {
		column := createMixFactors(in, m, col)
		out[col] = multiplyAndAdd(api, in, column)
	}
	return out
}
// MixLast computes a single output of Mix, namely element s:
// sum_j m[j][s] * in[j].
func MixLast(api frontend.API, in []frontend.Variable, m [][]*big.Int, s int) frontend.Variable {
	column := createMixFactors(in, m, s)
	return multiplyAndAdd(api, in, column)
}
// MixS applies the sparse mixing step of partial round r. With t = len(in)
// and base = (2t-1)*r:
//
//	out[0] = sum_i s[base+i] * in[i]
//	out[i] = in[i] + in[0] * s[base+t+i-1]   for 1 <= i < t
//
// A new slice is returned; in is left unchanged. s must contain at least
// base+2t-1 entries.
func MixS(api frontend.API, in []frontend.Variable, s []*big.Int, r int) []frontend.Variable {
	t := len(in)
	out := make([]frontend.Variable, t)
	if t == 0 {
		return out
	}
	base := (t*2 - 1) * r

	// Accumulate from the constant 0. The zero value of frontend.Variable is
	// nil, which is not a valid operand for api.Add, so out[0] must not be
	// used as the accumulator directly.
	var acc frontend.Variable = 0
	for i := range in {
		acc = api.Add(acc, api.Mul(s[base+i], in[i]))
	}
	out[0] = acc

	for i := 1; i < t; i++ {
		out[i] = api.Add(in[i], api.Mul(in[0], s[base+t+i-1]))
	}
	return out
}
// PoseidonEx evaluates the Poseidon permutation in-circuit. It uses the
// optimized round structure of circomlib's poseidon.circom: nRoundsF/2 full
// rounds, nRoundsP partial rounds with a sparse mixing step, then nRoundsF/2
// full rounds.
//
// The state has width t = len(inputs)+1. state[0] is initialState and
// state[1:] holds the inputs in order. inputs itself is not modified. The
// result holds the first nOuts elements of the final mixing step.
//
// Constants come from POSEIDON_C/S/M/P(t). Between 1 and 16 inputs are
// supported (one entry of nRoundsPC per width), and nOuts must lie in [0, t].
// The signature has no error return, so invalid arguments panic with a
// descriptive message rather than an index-out-of-range error.
func PoseidonEx(api frontend.API, inputs []frontend.Variable, initialState frontend.Variable, nOuts int) []frontend.Variable {
	// Using recommended parameters from whitepaper https://eprint.iacr.org/2019/458.pdf (table 2, table 8)
	// Generated by https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/calc_round_numbers.py
	// and rounded up to the nearest integer divisible by t.
	nRoundsPC := [16]int{56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68}
	if len(inputs) < 1 || len(inputs) > len(nRoundsPC) {
		panic("poseidon: PoseidonEx supports between 1 and 16 inputs")
	}
	t := len(inputs) + 1
	// Only t state elements exist to be output.
	if nOuts < 0 || nOuts > t {
		panic("poseidon: PoseidonEx nOuts must be between 0 and len(inputs)+1")
	}
	nRoundsF := 8
	nRoundsP := nRoundsPC[t-2]
	c := POSEIDON_C(t)
	s := POSEIDON_S(t)
	m := POSEIDON_M(t)
	p := POSEIDON_P(t)

	state := make([]frontend.Variable, t)
	state[0] = initialState
	copy(state[1:], inputs)

	// First half of the full rounds: S-box on every element, add constants, mix with m.
	state = Ark(api, state, c, 0)
	for r := 0; r < nRoundsF/2-1; r++ {
		for j := 0; j < t; j++ {
			state[j] = Sigma(api, state[j])
		}
		state = Ark(api, state, c, (r+1)*t)
		state = Mix(api, state, m)
	}
	for j := 0; j < t; j++ {
		state[j] = Sigma(api, state[j])
	}
	state = Ark(api, state, c, nRoundsF/2*t)
	// The last first-half round mixes with p instead of m; the partial rounds
	// that follow use the sparse matrices stored in s.
	state = Mix(api, state, p)

	// Partial rounds: S-box and round constant on state[0] only.
	for r := 0; r < nRoundsP; r++ {
		state[0] = Sigma(api, state[0])
		state[0] = api.Add(state[0], c[(nRoundsF/2+1)*t+r])
		// Sparse mix:
		//   new state[0] = sum_j s[(2t-1)r+j] * state[j]
		//   state[k]    += state[0] * s[(2t-1)r+t+k-1]  for k >= 1
		// Both formulas use state[0] from before this mix, so state[0] is
		// overwritten last.
		newState0 := frontend.Variable(0)
		for j := 0; j < len(state); j++ {
			mul := api.Mul(s[(t*2-1)*r+j], state[j])
			newState0 = api.Add(newState0, mul)
		}
		for k := 1; k < t; k++ {
			state[k] = api.Add(state[k], api.Mul(state[0], s[(t*2-1)*r+t+k-1]))
		}
		state[0] = newState0
	}

	// Second half of the full rounds; the constant offset skips the nRoundsP
	// single constants consumed by the partial rounds.
	for r := 0; r < nRoundsF/2-1; r++ {
		for j := 0; j < t; j++ {
			state[j] = Sigma(api, state[j])
		}
		state = Ark(api, state, c, (nRoundsF/2+1)*t+nRoundsP+r*t)
		state = Mix(api, state, m)
	}
	for j := 0; j < t; j++ {
		state[j] = Sigma(api, state[j])
	}

	// Final mix with m, computing only the requested outputs.
	out := make([]frontend.Variable, nOuts)
	for i := 0; i < nOuts; i++ {
		out[i] = MixLast(api, state, m, i)
	}
	return out
}
// Poseidon hashes inputs to a single field element. It runs PoseidonEx with
// an initial state element of 0 and one output, and returns that output.
func Poseidon(api frontend.API, inputs []frontend.Variable) frontend.Variable {
	return PoseidonEx(api, inputs, 0, 1)[0]
}