mirror of
https://github.com/logos-storage/plonky2-verifier.git
synced 2026-01-03 06:13:09 +00:00
count the number of operations in constraints, and fix the extreme inefficiency of Poseidon gate
This commit is contained in:
parent
9967a612a2
commit
e49a0cfdba
@ -11,6 +11,8 @@ import Prelude hiding ( (^) )
|
||||
|
||||
import Data.Array
|
||||
import Data.Char
|
||||
import Data.Monoid
|
||||
import Data.Semigroup
|
||||
|
||||
import Text.Show
|
||||
|
||||
@ -35,10 +37,10 @@ data Expr v
|
||||
-- instance Pretty var => Show (Expr var) where show = pretty
|
||||
|
||||
-- | Degree of the expression
|
||||
exprDegree :: Expr var -> Int
|
||||
exprDegree = go where
|
||||
exprDegree :: (var -> Int) -> Expr var -> Int
|
||||
exprDegree varDeg = go where
|
||||
go expr = case expr of
|
||||
VarE _ -> 1
|
||||
VarE v -> varDeg v
|
||||
LitE _ -> 0
|
||||
AddE e1 e2 -> max (go e1) (go e2)
|
||||
SubE e1 e2 -> max (go e1) (go e2)
|
||||
@ -68,6 +70,31 @@ mulE :: Expr var -> Expr var -> Expr var
|
||||
mulE = MulE
|
||||
-}
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Operation counting
|
||||
|
||||
data OperCount = MkOperCount
|
||||
{ numberOfAdds :: Int
|
||||
, numberOfMuls :: Int
|
||||
}
|
||||
deriving (Eq,Show)
|
||||
|
||||
instance Semigroup OperCount where
|
||||
(<>) (MkOperCount a1 m1) (MkOperCount a2 m2) = MkOperCount (a1+a2) (m1+m2)
|
||||
|
||||
instance Monoid OperCount where
|
||||
mempty = MkOperCount 0 0
|
||||
|
||||
exprOperCount :: Expr var -> OperCount
|
||||
exprOperCount = go where
|
||||
go expr = case expr of
|
||||
VarE _ -> mempty
|
||||
LitE e -> mempty
|
||||
AddE e1 e2 -> go e1 <> go e2 <> MkOperCount 1 0
|
||||
SubE e1 e2 -> go e1 <> go e2 <> MkOperCount 1 0
|
||||
MulE e1 e2 -> go e1 <> go e2 <> MkOperCount 0 1
|
||||
ImgE e -> go e <> MkOperCount 0 1
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * pretty printing
|
||||
|
||||
|
||||
@ -61,7 +61,13 @@ type Def_ = LocalDef Var_
|
||||
type Compute a = Program (Instr Var_) a
|
||||
|
||||
let_ :: String -> Expr Var_ -> Compute (Expr Var_)
|
||||
let_ name rhs = VarE <$> Instr (Let name rhs)
|
||||
let_ name rhs = case rhs of
|
||||
VarE _ -> return rhs
|
||||
LitE _ -> return rhs
|
||||
_ -> VarE <$> Instr (Let name rhs)
|
||||
|
||||
lets_ :: [Expr Var_] -> Compute [Expr Var_]
|
||||
lets_ = mapM (let_ "")
|
||||
|
||||
commit :: Expr Var_ -> Compute ()
|
||||
commit what = Instr (Commit what)
|
||||
@ -72,8 +78,11 @@ commitList = mapM_ commit
|
||||
--------------------------------------------------------------------------------
|
||||
-- | Straightline programs
|
||||
|
||||
data LocalDef v
|
||||
= MkLocalDef Int String (Expr v)
|
||||
data LocalDef v = MkLocalDef
|
||||
{ localDefVarIdx :: Int
|
||||
, localDefVarName :: String
|
||||
, localDefRHS :: Expr v
|
||||
}
|
||||
deriving (Eq,Show)
|
||||
|
||||
instance Pretty v => Pretty (LocalDef v) where
|
||||
@ -109,6 +118,12 @@ compileToStraightLine = fst . go emptyStraightLine where
|
||||
state' = MkStraightLine (def:localdefs) commits (counter+1)
|
||||
in (state', LocalVar counter name)
|
||||
|
||||
straightLineOperCount :: StraightLine -> OperCount
|
||||
straightLineOperCount (MkStraightLine{..}) = final where
|
||||
defs = map exprOperCount $ map localDefRHS localdefs
|
||||
coms = map exprOperCount $ commits
|
||||
final = mconcat defs <> mconcat coms
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
type Scope a = IntMap a
|
||||
|
||||
@ -24,11 +24,13 @@ import Hash.Constants
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
sbox :: Expr Var_ -> Compute (Expr Var_)
|
||||
sbox x = do
|
||||
sbox x0 = do
|
||||
x <- let_ "x1_" x0
|
||||
x2 <- let_ "x2_" (x *x )
|
||||
x3 <- let_ "x3_" (x *x2)
|
||||
x4 <- let_ "x4_" (x2*x2)
|
||||
return (x3*x4)
|
||||
x7 <- let_ "x7_" (x3*x4)
|
||||
return x7
|
||||
|
||||
flipFoldM :: (Foldable t, Monad m) => b -> t a -> (b -> a -> m b) -> m b
|
||||
flipFoldM s0 list action = foldM action s0 list
|
||||
@ -70,15 +72,15 @@ poseidonGateConstraints = do
|
||||
return $ mds state4
|
||||
|
||||
-- partial rounds
|
||||
let state' = zipWith (+) phase1 (map LitE $ elems fast_PARTIAL_FIRST_ROUND_CONSTANT)
|
||||
let state'' = mdsInitPartial state'
|
||||
state' <- lets_ $ zipWith (+) phase1 (map LitE $ elems fast_PARTIAL_FIRST_ROUND_CONSTANT)
|
||||
state'' <- lets_ $ mdsInitPartial state'
|
||||
phase2 <- flipFoldM state'' [0..21] $ \state1 r -> do
|
||||
let sbox_in = partial_sbox_in r
|
||||
commit $ (state1!!0) - sbox_in
|
||||
y <- sbox sbox_in
|
||||
let z = if r < 21 then y + LitE (fast_PARTIAL_ROUND_CONSTANTS!r) else y
|
||||
let state2 = z : tail state1
|
||||
return $ mdsFastPartial r state2
|
||||
state2 <- lets_ (z : tail state1)
|
||||
lets_ (mdsFastPartial r state2)
|
||||
|
||||
-- final full rounds
|
||||
phase3 <- flipFoldM phase2 [0..3] $ \state1 r -> do
|
||||
|
||||
@ -16,9 +16,12 @@ import Algebra.Goldilocks
|
||||
import Algebra.GoldilocksExt
|
||||
|
||||
import Gate.Base
|
||||
import Gate.Vars
|
||||
import Gate.Computation
|
||||
import Gate.Constraints
|
||||
|
||||
import Misc.Pretty
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
-- | As exported into JSON by our Plonky2 fork
|
||||
@ -68,15 +71,26 @@ test_fibonacci = do
|
||||
let const_prg = gateProgram (ConstantGate 2)
|
||||
|
||||
let arith_evals = runStraightLine (fmap fromBase arith_row) arith_prg
|
||||
let const_evals = runStraightLine (fmap fromBase const_row) const_prg
|
||||
let posei_evals = runStraightLine (fmap fromBase posei_row) posei_prg
|
||||
let const_evals = runStraightLine (fmap fromBase const_row) const_prg
|
||||
|
||||
putStrLn $ "number of constants in ArithmeticGate = " ++ show (length arith_evals)
|
||||
putStrLn $ "number of constants in ConstantGate = " ++ show (length const_evals)
|
||||
putStrLn $ "number of constants in PosiedoGate = " ++ show (length posei_evals)
|
||||
putStrLn $ "number of constraints in ArithmeticGate = " ++ show (length arith_evals)
|
||||
putStrLn $ "number of constraints in PosiedonGate = " ++ show (length posei_evals)
|
||||
-- putStrLn $ "number of constraints in PublicInputGate = " ++ show (length pubio_evals)
|
||||
putStrLn $ "number of constraints in ConstGate = " ++ show (length const_evals)
|
||||
|
||||
putStrLn $ "number of operations in ArithmeticGate = " ++ show (straightLineOperCount arith_prg)
|
||||
putStrLn $ "number of operations in PosiedonGate = " ++ show (straightLineOperCount posei_prg)
|
||||
putStrLn $ "number of operations in PublicInputGate = " ++ show (straightLineOperCount pubio_prg)
|
||||
putStrLn $ "number of operations in ConstGate = " ++ show (straightLineOperCount const_prg)
|
||||
|
||||
print arith_evals
|
||||
print const_evals
|
||||
print posei_evals
|
||||
print const_evals
|
||||
|
||||
-- printStraightLine posei_prg
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
main :: IO
|
||||
main = test_fibonacci
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user