From 8bfe0c6c103edf9d066b550689414bdbdbeb8d77 Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Thu, 23 Jan 2025 19:41:30 +0100 Subject: [PATCH] implement PoseidonMdsGate --- README.md | 2 +- src/Gate/Computation.hs | 3 +++ src/Gate/Constraints.hs | 28 ++++++++++++++-------------- src/Gate/Custom/CosetInterp.hs | 11 ++++------- src/Gate/Custom/Poseidon.hs | 17 ++++++++++++++++- 5 files changed, 38 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 39c950a..24a7f94 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Supported gates: - [x] NoopGate - [x] PublicInputGate - [x] PoseidonGate -- [ ] PoseidonMdsGate +- [x] PoseidonMdsGate - [x] RandomAccessGate - [ ] ReducingGate - [ ] ReducingExtensionGate diff --git a/src/Gate/Computation.hs b/src/Gate/Computation.hs index 783697d..f44ad6c 100644 --- a/src/Gate/Computation.hs +++ b/src/Gate/Computation.hs @@ -72,6 +72,9 @@ lets_ = mapM (let_ "") commit :: Expr Var_ -> Compute () commit what = Instr (Commit what) +commitExt :: Ext (Expr Var_) -> Compute () +commitExt (MkExt u v) = commit u >> commit v + commitList :: [Expr Var_] -> Compute () commitList = mapM_ commit diff --git a/src/Gate/Constraints.hs b/src/Gate/Constraints.hs index 7df69c6..5fb96b5 100644 --- a/src/Gate/Constraints.hs +++ b/src/Gate/Constraints.hs @@ -49,8 +49,7 @@ gateComputation gate = let j = 8*i let c0 = fromBase (cnst 0) let c1 = fromBase (cnst 1) - let MkExt u v = wireExt (j+6) - c0 * wireExt j * wireExt (j+2) - c1 * wireExt (j+4) - commitList [ u , v ] + commitExt $ wireExt (j+6) - c0 * wireExt j * wireExt (j+2) - c1 * wireExt (j+4) -- `sum b^i * limbs[i] - out = 0`, and `0 <= limb[i] < B` is enforced BaseSumGate num_limbs base @@ -79,8 +78,7 @@ gateComputation gate = MulExtensionGate num_ops -> forM_ (range num_ops) $ \i -> do let j = 6*i - let MkExt u v = wireExt (j+4) - fromBase (cnst 0) * wireExt j * wireExt (j+2) - commitList [ u , v ] + commitExt $ wireExt (j+4) - fromBase (cnst 0) * wireExt j * wireExt (j+2) NoopGate -> return () @@ -93,7 +91,7 @@ gateComputation gate = k -> error ( "gateConstraints/PoseidonGate: unsupported width " ++ show k) PoseidonMdsGate hash_width -> case hash_width of - 12 -> todo -- poseidonMdsGateConstraints + 12 -> poseidonMdsGateConstraints k -> error ( "gateConstraints/PoseidonMdsGate: unsupported width " ++ show k) RandomAccessGate num_bits num_copies num_extra_constants @@ -137,14 +135,16 @@ exponentiationGateConstraints num_power_bits = testCompute :: Compute () -> [FExt] testCompute = runComputation testEvaluationVarsExt -testArtihExtGate = testCompute $ gateComputation (ArithmeticExtensionGate 10) -testBaseSum2 = testCompute $ gateComputation (BaseSumGate 13 2) -testBaseSum3 = testCompute $ gateComputation (BaseSumGate 13 3) -testExpoGate = testCompute $ gateComputation (ExponentiationGate 13) -testMulExtGate = testCompute $ gateComputation (MulExtensionGate 13) -testCosetGate3 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 3) -testCosetGate4 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 4) -testCosetGate5 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 5) -testRandAccGate = testCompute $ randomAccessGateConstraints $ randomAccessGateConfig (Log2 4) +testArtihExtGate = testCompute $ gateComputation (ArithmeticExtensionGate 10) +testBaseSum2 = testCompute $ gateComputation (BaseSumGate 13 2) +testBaseSum3 = testCompute $ gateComputation (BaseSumGate 13 3) +testExpoGate = testCompute $ gateComputation (ExponentiationGate 13) +testMulExtGate = testCompute $ gateComputation (MulExtensionGate 13) +testCosetGate3 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 3) +testCosetGate4 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 4) +testCosetGate5 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 5) +testRandAccGate = testCompute $ randomAccessGateConstraints $ randomAccessGateConfig (Log2 4) +testPoseidonGate = testCompute $ gateComputation (PoseidonGate 12) +testPoseidonMdsGate = testCompute $ gateComputation (PoseidonMdsGate 12) -------------------------------------------------------------------------------- diff --git a/src/Gate/Custom/CosetInterp.hs b/src/Gate/Custom/CosetInterp.hs index 31a69f8..316518b 100644 --- a/src/Gate/Custom/CosetInterp.hs +++ b/src/Gate/Custom/CosetInterp.hs @@ -51,8 +51,7 @@ calcBarycentricWeights locations = weights where cosetInterpolationGateConstraints :: CosetInterpolationGateConfig -> Compute () cosetInterpolationGateConstraints (MkCICfg{..}) = do - let MkExt u v = eval_loc - scaleExt coset_shift shifted_loc - commitList [ u , v ] + commitExt $ eval_loc - scaleExt coset_shift shifted_loc let initials = initial : [ (tmp_eval i , tmp_prod i) | i <- [0..n_intermediates-1] ] let chunks = zip3 chunked_domain chunked_values chunked_weights @@ -61,13 +60,11 @@ cosetInterpolationGateConstraints (MkCICfg{..}) = do let stuff = zipWith worker initials chunks forM_ (zip [0..] (init stuff)) $ \(i,(eval,prod)) -> do - let MkExt u1 v1 = tmp_eval i - eval - let MkExt u2 v2 = tmp_prod i - prod - commitList [ u1 , v1 , u2 , v2 ] + commitExt (tmp_eval i - eval) + commitExt (tmp_prod i - prod) let (final_eval,_) = last stuff - let MkExt u v = eval_result - final_eval - commitList [ u , v ] + commitExt (eval_result - final_eval) where diff --git a/src/Gate/Custom/Poseidon.hs b/src/Gate/Custom/Poseidon.hs index 8f4d3ff..a56073c 100644 --- a/src/Gate/Custom/Poseidon.hs +++ b/src/Gate/Custom/Poseidon.hs @@ -1,5 +1,5 @@ --- | Plonky2's Poseidon gate +-- | Plonky2's Poseidon and PoseidonMds gates -- {-# LANGUAGE StrictData, RecordWildCards #-} @@ -15,6 +15,7 @@ import Control.Monad import Control.Monad.State.Strict import Algebra.Goldilocks +import Algebra.GoldilocksExt import Algebra.Expr import Gate.Vars @@ -45,6 +46,20 @@ type PS = [Expr Var_] -------------------------------------------------------------------------------- +poseidonMdsGateConstraints :: Compute () +poseidonMdsGateConstraints = do + + forM_ [0..11] $ \i -> do + let result = sum [ scaleExt (LitE (mdsMatrixCoeff i j)) (input j) | j<-[0..11] ] :: Ext (Expr_) + commitExt (output i - result) + + where + -- witness variables + input i = wireExt (2* i ) + output i = wireExt (2*(i+12)) + +-------------------------------------------------------------------------------- + poseidonGateConstraints :: Compute () poseidonGateConstraints = do