-- Mirror of https://github.com/logos-storage/plonky2-verifier.git
-- (synced 2026-01-03 22:33:11 +00:00; original file: 139 lines, 4.1 KiB, Haskell)
|
|
-- | Plonky2's Poseidon gate
|
|
--
|
|
|
|
{-# LANGUAGE StrictData, RecordWildCards #-}
|
|
module Gate.Poseidon where
|
|
|
|
--------------------------------------------------------------------------------
|
|
|
|
import Data.Array hiding (range)
|
|
import Data.Char
|
|
import Data.Foldable
|
|
|
|
import Control.Monad
|
|
import Control.Monad.State.Strict
|
|
|
|
import Algebra.Goldilocks
|
|
import Algebra.Expr
|
|
|
|
import Gate.Vars
|
|
import Gate.Computation
|
|
import Hash.Constants
|
|
|
|
import Misc.Aux
|
|
|
|
--------------------------------------------------------------------------------
|
|
|
|
-- | Poseidon S-box: raises its argument to the 7th power.
--
-- The power is assembled from intermediates bound with 'let_' under the
-- names @x1_@, @x2_@, @x3_@, @x4_@, @x7_@:
-- @x^2 = x*x@, @x^3 = x*x^2@, @x^4 = x^2*x^2@, @x^7 = x^3*x^4@,
-- so every step is a single multiplication of already-bound values.
sbox :: Expr Var_ -> Compute (Expr Var_)
sbox arg = do
  p1 <- let_ "x1_" arg
  p2 <- let_ "x2_" (p1 * p1)
  p3 <- let_ "x3_" (p1 * p2)
  p4 <- let_ "x4_" (p2 * p2)
  let_ "x7_" (p3 * p4)
|
|
|
|
-- | Monadic left fold with its arguments in "loop" order: the initial
-- accumulator, then the collection, then the step function. This lets the
-- step be written as a trailing lambda:
--
-- > flipFoldM s0 xs $ \acc x -> do ...
flipFoldM :: (Foldable t, Monad m) => b -> t a -> (b -> a -> m b) -> m b
flipFoldM start xs step = foldlM step start xs
|
|
|
|
-- | Monadic left fold in "loop" argument order (initial accumulator,
-- collection, step function) that discards the final accumulator; only the
-- monadic effects of the fold are kept.
flipFoldM_ :: (Foldable t, Monad m) => b -> t a -> (b -> a -> m b) -> m ()
flipFoldM_ start xs step = foldM_ step start xs
|
|
|
|
-- | Poseidon state: one symbolic expression per state element.
-- 'poseidonGateConstraints' builds and transforms 12-element states; the
-- alias itself is a plain list and does not enforce any length.
type PS = [Expr Var_]
|
|
|
|
--------------------------------------------------------------------------------
|
|
|
|
-- | Emit all constraints of Plonky2's Poseidon gate.
--
-- The constraint order is intended to mirror Plonky2's
-- @PoseidonGate::eval_unfiltered@ (NOTE(review): confirm against the Plonky2
-- version being verified). With the parameters in the @where@ clause the
-- gate emits 123 constraints, in this order:
--
--   1. @swap_flag@ is boolean (1);
--   2. each delta wire equals @swap_flag * (rhs - lhs)@ (4);
--   3. the stored S-box inputs of first-half full rounds 1..3 (3*12);
--   4. the stored S-box inputs of the partial rounds (22);
--   5. the stored S-box inputs of the second-half full rounds (4*12);
--   6. the output wires equal the final state (12).
--
-- The S-box inputs of the very first full round are not stored in wires;
-- they are recomputed from the (possibly swapped) inputs.
poseidonGateConstraints :: Compute ()
poseidonGateConstraints = do

    -- Merkle swap: when swap_flag is set, the first and second 'swapLen'-element
    -- groups of the input are exchanged through the delta wires.
    let input_lhs i = input i
        input_rhs i = input (i + swapLen)
    commit $ swap_flag * (swap_flag - 1)
    commitList [ swap_flag * (input_rhs i - input_lhs i) - delta i | i <- range swapLen ]

    -- swapped inputs
    let state0 :: PS
        state0 = [ input_lhs i + delta i | i <- [0 .. swapLen - 1] ]
              ++ [ input_rhs i - delta i | i <- [0 .. swapLen - 1] ]
              ++ [ input i | i <- [2 * swapLen .. width - 1] ]

    -- initial full rounds (round constants 0 .. halfFull-1)
    phase1 <- flipFoldM state0 [0 .. halfFull - 1] $ \state1 r -> do
        let state2 = plus_rc r state1
        state3 <- if r == 0
            then return state2      -- round 0: S-box inputs are not stored in wires
            else do
                let sbox_in = [ initial_sbox_in r i | i <- range width ]
                commitList (zipWith (-) state2 sbox_in)
                return sbox_in
        state4 <- mapM sbox state3
        return (mds state4)

    -- partial rounds (Plonky2's "fast" partial-round form)
    state'  <- lets_ $ zipWith (+) phase1 (map LitE $ elems fast_PARTIAL_FIRST_ROUND_CONSTANT)
    state'' <- lets_ $ mdsInitPartial state'
    phase2 <- flipFoldM state'' [0 .. nPartial - 1] $ \state1 r -> do
        let (s0, rest) = splitState "partial round" state1
            sbox_in    = partial_sbox_in r
        commit $ s0 - sbox_in
        y <- sbox sbox_in
        -- the last partial round adds no round constant after the S-box
        let z | r < nPartial - 1 = y + LitE (fast_PARTIAL_ROUND_CONSTANTS ! r)
              | otherwise        = y
        state2 <- lets_ (z : rest)
        lets_ (mdsFastPartial r state2)

    -- final full rounds (round constants continue after the partial rounds)
    phase3 <- flipFoldM phase2 [0 .. halfFull - 1] $ \state1 r -> do
        let state2  = plus_rc (halfFull + nPartial + r) state1
            sbox_in = [ final_sbox_in r i | i <- range width ]
        commitList (zipWith (-) state2 sbox_in)
        state3 <- mapM sbox sbox_in
        return (mds state3)

    -- constrain the output to be the result
    commitList (zipWith (-) phase3 [ output i | i <- range width ])

  where

    -- Poseidon parameters; all wire offsets and round indices derive from these
    width, halfFull, nPartial, swapLen :: Int
    width    = 12   -- number of state elements
    halfFull = 4    -- full rounds before, and again after, the partial rounds
    nPartial = 22   -- partial rounds
    swapLen  = 4    -- elements exchanged by the Merkle swap (inputs 0..3 vs 4..7)

    -- multiply by the MDS matrix
    mds :: PS -> PS
    mds state =
        [ sum [ LitE (mdsMatrixCoeff i j) * x | (j, x) <- zip [0 ..] state ]
        | i <- range width
        ]

    dotProd :: PS -> [F] -> Expr Var_
    dotProd es cs = sum $ zipWith (\e c -> e * LitE c) es cs

    -- initial partial-round matrix: keeps element 0, mixes the remaining ones
    mdsInitPartial :: PS -> PS
    mdsInitPartial state =
        first : [ sum [ LitE (partialMdsMatrixCoeff i j) * x | (j, x) <- zip [0 ..] rest ]
                | i <- range (width - 1)
                ]
      where
        (first, rest) = splitState "mdsInitPartial" state

    -- sparse partial-round matrix for round r:
    -- new s0 = [m00 | w_hat_r] . state ; new s_i = s_i + s0 * v_r[i-1]
    mdsFastPartial :: Int -> PS -> PS
    mdsFastPartial r state =
        d : [ x + s0 * LitE t | (x, t) <- zip rest (elems (fast_PARTIAL_ROUND_VS ! r)) ]
      where
        (s0, rest) = splitState "mdsFastPartial" state
        cs         = mdsMatrixCoeff 0 0 : elems (fast_PARTIAL_ROUND_W_HATS ! r)
        d          = dotProd state cs

    -- add round constants
    plus_rc :: Int -> PS -> PS
    plus_rc r state = zipWith (+) state (map LitE $ elems (all_ROUND_CONSTANTS ! r))

    -- Split a state into its first element and the rest. States built here
    -- always have 'width' elements, so the empty case should be unreachable;
    -- it is reported explicitly rather than as an anonymous pattern failure.
    splitState :: String -> PS -> (Expr Var_, PS)
    splitState _   (s : ss) = (s, ss)
    splitState ctx []       = error ("Gate.Poseidon." ++ ctx ++ ": empty Poseidon state")

    -- witness wire layout:
    --   [0, width)             inputs
    --   [width, 2*width)       outputs
    --   2*width                swap flag
    --   startDelta   + i       delta wires,                    0 <= i < swapLen
    --   startFull0   + ...     first-half full-round S-box inputs (rounds 1..halfFull-1)
    --   startPartial + r       partial-round S-box inputs,     0 <= r < nPartial
    --   startFull1   + ...     second-half full-round S-box inputs
    input i             = wire i
    output i            = wire (width + i)
    swap_flag           = wire (2 * width)
    delta i             = wire (startDelta + i)
    initial_sbox_in r i = wire (startFull0 + width * (r - 1) + i)   -- 0 < r < halfFull , 0 <= i < width
    partial_sbox_in r   = wire (startPartial + r)                   -- 0 <= r < nPartial
    final_sbox_in r i   = wire (startFull1 + width * r + i)         -- 0 <= r < halfFull , 0 <= i < width

    startDelta, startFull0, startPartial, startFull1 :: Int
    startDelta   = 2 * width + 1                        -- 25
    startFull0   = startDelta + swapLen                 -- 29
    startPartial = startFull0 + width * (halfFull - 1)  -- 65
    startFull1   = startPartial + nPartial              -- 87
|
|
|
|
--------------------------------------------------------------------------------
|
|
|