implement PoseidonMdsGate

This commit is contained in:
Balazs Komuves 2025-01-23 19:41:30 +01:00
parent 644832ec48
commit 8bfe0c6c10
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
5 changed files with 38 additions and 23 deletions

View File

@ -45,7 +45,7 @@ Supported gates:
- [x] NoopGate - [x] NoopGate
- [x] PublicInputGate - [x] PublicInputGate
- [x] PoseidonGate - [x] PoseidonGate
- [ ] PoseidonMdsGate - [x] PoseidonMdsGate
- [x] RandomAccessGate - [x] RandomAccessGate
- [ ] ReducingGate - [ ] ReducingGate
- [ ] ReducingExtensionGate - [ ] ReducingExtensionGate

View File

@ -72,6 +72,9 @@ lets_ = mapM (let_ "")
commit :: Expr Var_ -> Compute () commit :: Expr Var_ -> Compute ()
commit what = Instr (Commit what) commit what = Instr (Commit what)
commitExt :: Ext (Expr Var_) -> Compute ()
commitExt (MkExt u v) = commit u >> commit v
commitList :: [Expr Var_] -> Compute () commitList :: [Expr Var_] -> Compute ()
commitList = mapM_ commit commitList = mapM_ commit

View File

@ -49,8 +49,7 @@ gateComputation gate =
let j = 8*i let j = 8*i
let c0 = fromBase (cnst 0) let c0 = fromBase (cnst 0)
let c1 = fromBase (cnst 1) let c1 = fromBase (cnst 1)
let MkExt u v = wireExt (j+6) - c0 * wireExt j * wireExt (j+2) - c1 * wireExt (j+4) commitExt $ wireExt (j+6) - c0 * wireExt j * wireExt (j+2) - c1 * wireExt (j+4)
commitList [ u , v ]
-- `sum b^i * limbs[i] - out = 0`, and `0 <= limb[i] < B` is enforced -- `sum b^i * limbs[i] - out = 0`, and `0 <= limb[i] < B` is enforced
BaseSumGate num_limbs base BaseSumGate num_limbs base
@ -79,8 +78,7 @@ gateComputation gate =
MulExtensionGate num_ops MulExtensionGate num_ops
-> forM_ (range num_ops) $ \i -> do -> forM_ (range num_ops) $ \i -> do
let j = 6*i let j = 6*i
let MkExt u v = wireExt (j+4) - fromBase (cnst 0) * wireExt j * wireExt (j+2) commitExt $ wireExt (j+4) - fromBase (cnst 0) * wireExt j * wireExt (j+2)
commitList [ u , v ]
NoopGate -> return () NoopGate -> return ()
@ -93,7 +91,7 @@ gateComputation gate =
k -> error ( "gateConstraints/PoseidonGate: unsupported width " ++ show k) k -> error ( "gateConstraints/PoseidonGate: unsupported width " ++ show k)
PoseidonMdsGate hash_width -> case hash_width of PoseidonMdsGate hash_width -> case hash_width of
12 -> todo -- poseidonMdsGateConstraints 12 -> poseidonMdsGateConstraints
k -> error ( "gateConstraints/PoseidonMdsGate: unsupported width " ++ show k) k -> error ( "gateConstraints/PoseidonMdsGate: unsupported width " ++ show k)
RandomAccessGate num_bits num_copies num_extra_constants RandomAccessGate num_bits num_copies num_extra_constants
@ -137,14 +135,16 @@ exponentiationGateConstraints num_power_bits =
testCompute :: Compute () -> [FExt] testCompute :: Compute () -> [FExt]
testCompute = runComputation testEvaluationVarsExt testCompute = runComputation testEvaluationVarsExt
testArtihExtGate = testCompute $ gateComputation (ArithmeticExtensionGate 10) testArtihExtGate = testCompute $ gateComputation (ArithmeticExtensionGate 10)
testBaseSum2 = testCompute $ gateComputation (BaseSumGate 13 2) testBaseSum2 = testCompute $ gateComputation (BaseSumGate 13 2)
testBaseSum3 = testCompute $ gateComputation (BaseSumGate 13 3) testBaseSum3 = testCompute $ gateComputation (BaseSumGate 13 3)
testExpoGate = testCompute $ gateComputation (ExponentiationGate 13) testExpoGate = testCompute $ gateComputation (ExponentiationGate 13)
testMulExtGate = testCompute $ gateComputation (MulExtensionGate 13) testMulExtGate = testCompute $ gateComputation (MulExtensionGate 13)
testCosetGate3 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 3) testCosetGate3 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 3)
testCosetGate4 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 4) testCosetGate4 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 4)
testCosetGate5 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 5) testCosetGate5 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 5)
testRandAccGate = testCompute $ randomAccessGateConstraints $ randomAccessGateConfig (Log2 4) testRandAccGate = testCompute $ randomAccessGateConstraints $ randomAccessGateConfig (Log2 4)
testPoseidonGate = testCompute $ gateComputation (PoseidonGate 12)
testPoseidonMdsGate = testCompute $ gateComputation (PoseidonMdsGate 12)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------

View File

@ -51,8 +51,7 @@ calcBarycentricWeights locations = weights where
cosetInterpolationGateConstraints :: CosetInterpolationGateConfig -> Compute () cosetInterpolationGateConstraints :: CosetInterpolationGateConfig -> Compute ()
cosetInterpolationGateConstraints (MkCICfg{..}) = do cosetInterpolationGateConstraints (MkCICfg{..}) = do
let MkExt u v = eval_loc - scaleExt coset_shift shifted_loc commitExt $ eval_loc - scaleExt coset_shift shifted_loc
commitList [ u , v ]
let initials = initial : [ (tmp_eval i , tmp_prod i) | i <- [0..n_intermediates-1] ] let initials = initial : [ (tmp_eval i , tmp_prod i) | i <- [0..n_intermediates-1] ]
let chunks = zip3 chunked_domain chunked_values chunked_weights let chunks = zip3 chunked_domain chunked_values chunked_weights
@ -61,13 +60,11 @@ cosetInterpolationGateConstraints (MkCICfg{..}) = do
let stuff = zipWith worker initials chunks let stuff = zipWith worker initials chunks
forM_ (zip [0..] (init stuff)) $ \(i,(eval,prod)) -> do forM_ (zip [0..] (init stuff)) $ \(i,(eval,prod)) -> do
let MkExt u1 v1 = tmp_eval i - eval commitExt (tmp_eval i - eval)
let MkExt u2 v2 = tmp_prod i - prod commitExt (tmp_prod i - prod)
commitList [ u1 , v1 , u2 , v2 ]
let (final_eval,_) = last stuff let (final_eval,_) = last stuff
let MkExt u v = eval_result - final_eval commitExt (eval_result - final_eval)
commitList [ u , v ]
where where

View File

@ -1,5 +1,5 @@
-- | Plonky2's Poseidon gate -- | Plonky2's Poseidon and PoseidonMds gates
-- --
{-# LANGUAGE StrictData, RecordWildCards #-} {-# LANGUAGE StrictData, RecordWildCards #-}
@ -15,6 +15,7 @@ import Control.Monad
import Control.Monad.State.Strict import Control.Monad.State.Strict
import Algebra.Goldilocks import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Algebra.Expr import Algebra.Expr
import Gate.Vars import Gate.Vars
@ -45,6 +46,20 @@ type PS = [Expr Var_]
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
poseidonMdsGateConstraints :: Compute ()
poseidonMdsGateConstraints = do
forM_ [0..11] $ \i -> do
let result = sum [ scaleExt (LitE (mdsMatrixCoeff i j)) (input j) | j<-[0..11] ] :: Ext (Expr_)
commitExt (output i - result)
where
-- witness variables
input i = wireExt (2* i )
output i = wireExt (2*(i+12))
--------------------------------------------------------------------------------
poseidonGateConstraints :: Compute () poseidonGateConstraints :: Compute ()
poseidonGateConstraints = do poseidonGateConstraints = do