count the number of operations in constraints, and fix the extreme inefficiency of Poseidon gate

This commit is contained in:
Balazs Komuves 2024-12-15 14:56:16 +01:00
parent 9967a612a2
commit e49a0cfdba
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
4 changed files with 75 additions and 17 deletions

View File

@ -11,6 +11,8 @@ import Prelude hiding ( (^) )
import Data.Array
import Data.Char
import Data.Monoid
import Data.Semigroup
import Text.Show
@ -35,10 +37,10 @@ data Expr v
-- instance Pretty var => Show (Expr var) where show = pretty
-- | Degree of the expression
exprDegree :: Expr var -> Int
exprDegree = go where
exprDegree :: (var -> Int) -> Expr var -> Int
exprDegree varDeg = go where
go expr = case expr of
VarE _ -> 1
VarE v -> varDeg v
LitE _ -> 0
AddE e1 e2 -> max (go e1) (go e2)
SubE e1 e2 -> max (go e1) (go e2)
@ -68,6 +70,31 @@ mulE :: Expr var -> Expr var -> Expr var
mulE = MulE
-}
--------------------------------------------------------------------------------
-- * Operation counting
data OperCount = MkOperCount
{ numberOfAdds :: Int
, numberOfMuls :: Int
}
deriving (Eq,Show)
instance Semigroup OperCount where
(<>) (MkOperCount a1 m1) (MkOperCount a2 m2) = MkOperCount (a1+a2) (m1+m2)
instance Monoid OperCount where
mempty = MkOperCount 0 0
exprOperCount :: Expr var -> OperCount
exprOperCount = go where
go expr = case expr of
VarE _ -> mempty
LitE e -> mempty
AddE e1 e2 -> go e1 <> go e2 <> MkOperCount 1 0
SubE e1 e2 -> go e1 <> go e2 <> MkOperCount 1 0
MulE e1 e2 -> go e1 <> go e2 <> MkOperCount 0 1
ImgE e -> go e <> MkOperCount 0 1
--------------------------------------------------------------------------------
-- * pretty printing

View File

@ -61,7 +61,13 @@ type Def_ = LocalDef Var_
type Compute a = Program (Instr Var_) a
let_ :: String -> Expr Var_ -> Compute (Expr Var_)
let_ name rhs = VarE <$> Instr (Let name rhs)
let_ name rhs = case rhs of
VarE _ -> return rhs
LitE _ -> return rhs
_ -> VarE <$> Instr (Let name rhs)
lets_ :: [Expr Var_] -> Compute [Expr Var_]
lets_ = mapM (let_ "")
commit :: Expr Var_ -> Compute ()
commit what = Instr (Commit what)
@ -72,8 +78,11 @@ commitList = mapM_ commit
--------------------------------------------------------------------------------
-- | Straightline programs
data LocalDef v
= MkLocalDef Int String (Expr v)
data LocalDef v = MkLocalDef
{ localDefVarIdx :: Int
, localDefVarName :: String
, localDefRHS :: Expr v
}
deriving (Eq,Show)
instance Pretty v => Pretty (LocalDef v) where
@ -109,6 +118,12 @@ compileToStraightLine = fst . go emptyStraightLine where
state' = MkStraightLine (def:localdefs) commits (counter+1)
in (state', LocalVar counter name)
straightLineOperCount :: StraightLine -> OperCount
straightLineOperCount (MkStraightLine{..}) = final where
defs = map exprOperCount $ map localDefRHS localdefs
coms = map exprOperCount $ commits
final = mconcat defs <> mconcat coms
--------------------------------------------------------------------------------
type Scope a = IntMap a

View File

@ -24,11 +24,13 @@ import Hash.Constants
--------------------------------------------------------------------------------
sbox :: Expr Var_ -> Compute (Expr Var_)
sbox x = do
sbox x0 = do
x <- let_ "x1_" x0
x2 <- let_ "x2_" (x *x )
x3 <- let_ "x3_" (x *x2)
x4 <- let_ "x4_" (x2*x2)
return (x3*x4)
x7 <- let_ "x7_" (x3*x4)
return x7
flipFoldM :: (Foldable t, Monad m) => b -> t a -> (b -> a -> m b) -> m b
flipFoldM s0 list action = foldM action s0 list
@ -70,15 +72,15 @@ poseidonGateConstraints = do
return $ mds state4
-- partial rounds
let state' = zipWith (+) phase1 (map LitE $ elems fast_PARTIAL_FIRST_ROUND_CONSTANT)
let state'' = mdsInitPartial state'
state' <- lets_ $ zipWith (+) phase1 (map LitE $ elems fast_PARTIAL_FIRST_ROUND_CONSTANT)
state'' <- lets_ $ mdsInitPartial state'
phase2 <- flipFoldM state'' [0..21] $ \state1 r -> do
let sbox_in = partial_sbox_in r
commit $ (state1!!0) - sbox_in
y <- sbox sbox_in
let z = if r < 21 then y + LitE (fast_PARTIAL_ROUND_CONSTANTS!r) else y
let state2 = z : tail state1
return $ mdsFastPartial r state2
state2 <- lets_ (z : tail state1)
lets_ (mdsFastPartial r state2)
-- final full rounds
phase3 <- flipFoldM phase2 [0..3] $ \state1 r -> do

View File

@ -16,9 +16,12 @@ import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Gate.Base
import Gate.Vars
import Gate.Computation
import Gate.Constraints
import Misc.Pretty
--------------------------------------------------------------------------------
-- | As exported into JSON by our Plonky2 fork
@ -68,15 +71,26 @@ test_fibonacci = do
let const_prg = gateProgram (ConstantGate 2)
let arith_evals = runStraightLine (fmap fromBase arith_row) arith_prg
let const_evals = runStraightLine (fmap fromBase const_row) const_prg
let posei_evals = runStraightLine (fmap fromBase posei_row) posei_prg
let const_evals = runStraightLine (fmap fromBase const_row) const_prg
putStrLn $ "number of constants in ArithmeticGate = " ++ show (length arith_evals)
putStrLn $ "number of constants in ConstantGate = " ++ show (length const_evals)
putStrLn $ "number of constants in PosiedoGate = " ++ show (length posei_evals)
putStrLn $ "number of constraints in ArithmeticGate = " ++ show (length arith_evals)
putStrLn $ "number of constraints in PosiedonGate = " ++ show (length posei_evals)
-- putStrLn $ "number of constraints in PublicInputGate = " ++ show (length pubio_evals)
putStrLn $ "number of constraints in ConstGate = " ++ show (length const_evals)
putStrLn $ "number of operations in ArithmeticGate = " ++ show (straightLineOperCount arith_prg)
putStrLn $ "number of operations in PosiedonGate = " ++ show (straightLineOperCount posei_prg)
putStrLn $ "number of operations in PublicInputGate = " ++ show (straightLineOperCount pubio_prg)
putStrLn $ "number of operations in ConstGate = " ++ show (straightLineOperCount const_prg)
print arith_evals
print const_evals
print posei_evals
print const_evals
-- printStraightLine posei_prg
--------------------------------------------------------------------------------
main :: IO
main = test_fibonacci