mirror of
https://github.com/logos-storage/plonky2-verifier.git
synced 2026-01-10 09:43:08 +00:00
count the number of operations in constraints, and fix the extreme inefficiency of Poseidon gate
This commit is contained in:
parent
9967a612a2
commit
e49a0cfdba
@ -11,6 +11,8 @@ import Prelude hiding ( (^) )
|
|||||||
|
|
||||||
import Data.Array
|
import Data.Array
|
||||||
import Data.Char
|
import Data.Char
|
||||||
|
import Data.Monoid
|
||||||
|
import Data.Semigroup
|
||||||
|
|
||||||
import Text.Show
|
import Text.Show
|
||||||
|
|
||||||
@ -35,10 +37,10 @@ data Expr v
|
|||||||
-- instance Pretty var => Show (Expr var) where show = pretty
|
-- instance Pretty var => Show (Expr var) where show = pretty
|
||||||
|
|
||||||
-- | Degree of the expression
|
-- | Degree of the expression
|
||||||
exprDegree :: Expr var -> Int
|
exprDegree :: (var -> Int) -> Expr var -> Int
|
||||||
exprDegree = go where
|
exprDegree varDeg = go where
|
||||||
go expr = case expr of
|
go expr = case expr of
|
||||||
VarE _ -> 1
|
VarE v -> varDeg v
|
||||||
LitE _ -> 0
|
LitE _ -> 0
|
||||||
AddE e1 e2 -> max (go e1) (go e2)
|
AddE e1 e2 -> max (go e1) (go e2)
|
||||||
SubE e1 e2 -> max (go e1) (go e2)
|
SubE e1 e2 -> max (go e1) (go e2)
|
||||||
@ -68,6 +70,31 @@ mulE :: Expr var -> Expr var -> Expr var
|
|||||||
mulE = MulE
|
mulE = MulE
|
||||||
-}
|
-}
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- * Operation counting
|
||||||
|
|
||||||
|
data OperCount = MkOperCount
|
||||||
|
{ numberOfAdds :: Int
|
||||||
|
, numberOfMuls :: Int
|
||||||
|
}
|
||||||
|
deriving (Eq,Show)
|
||||||
|
|
||||||
|
instance Semigroup OperCount where
|
||||||
|
(<>) (MkOperCount a1 m1) (MkOperCount a2 m2) = MkOperCount (a1+a2) (m1+m2)
|
||||||
|
|
||||||
|
instance Monoid OperCount where
|
||||||
|
mempty = MkOperCount 0 0
|
||||||
|
|
||||||
|
exprOperCount :: Expr var -> OperCount
|
||||||
|
exprOperCount = go where
|
||||||
|
go expr = case expr of
|
||||||
|
VarE _ -> mempty
|
||||||
|
LitE e -> mempty
|
||||||
|
AddE e1 e2 -> go e1 <> go e2 <> MkOperCount 1 0
|
||||||
|
SubE e1 e2 -> go e1 <> go e2 <> MkOperCount 1 0
|
||||||
|
MulE e1 e2 -> go e1 <> go e2 <> MkOperCount 0 1
|
||||||
|
ImgE e -> go e <> MkOperCount 0 1
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
-- * pretty printing
|
-- * pretty printing
|
||||||
|
|
||||||
|
|||||||
@ -61,7 +61,13 @@ type Def_ = LocalDef Var_
|
|||||||
type Compute a = Program (Instr Var_) a
|
type Compute a = Program (Instr Var_) a
|
||||||
|
|
||||||
let_ :: String -> Expr Var_ -> Compute (Expr Var_)
|
let_ :: String -> Expr Var_ -> Compute (Expr Var_)
|
||||||
let_ name rhs = VarE <$> Instr (Let name rhs)
|
let_ name rhs = case rhs of
|
||||||
|
VarE _ -> return rhs
|
||||||
|
LitE _ -> return rhs
|
||||||
|
_ -> VarE <$> Instr (Let name rhs)
|
||||||
|
|
||||||
|
lets_ :: [Expr Var_] -> Compute [Expr Var_]
|
||||||
|
lets_ = mapM (let_ "")
|
||||||
|
|
||||||
commit :: Expr Var_ -> Compute ()
|
commit :: Expr Var_ -> Compute ()
|
||||||
commit what = Instr (Commit what)
|
commit what = Instr (Commit what)
|
||||||
@ -72,8 +78,11 @@ commitList = mapM_ commit
|
|||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
-- | Straightline programs
|
-- | Straightline programs
|
||||||
|
|
||||||
data LocalDef v
|
data LocalDef v = MkLocalDef
|
||||||
= MkLocalDef Int String (Expr v)
|
{ localDefVarIdx :: Int
|
||||||
|
, localDefVarName :: String
|
||||||
|
, localDefRHS :: Expr v
|
||||||
|
}
|
||||||
deriving (Eq,Show)
|
deriving (Eq,Show)
|
||||||
|
|
||||||
instance Pretty v => Pretty (LocalDef v) where
|
instance Pretty v => Pretty (LocalDef v) where
|
||||||
@ -109,6 +118,12 @@ compileToStraightLine = fst . go emptyStraightLine where
|
|||||||
state' = MkStraightLine (def:localdefs) commits (counter+1)
|
state' = MkStraightLine (def:localdefs) commits (counter+1)
|
||||||
in (state', LocalVar counter name)
|
in (state', LocalVar counter name)
|
||||||
|
|
||||||
|
straightLineOperCount :: StraightLine -> OperCount
|
||||||
|
straightLineOperCount (MkStraightLine{..}) = final where
|
||||||
|
defs = map exprOperCount $ map localDefRHS localdefs
|
||||||
|
coms = map exprOperCount $ commits
|
||||||
|
final = mconcat defs <> mconcat coms
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
type Scope a = IntMap a
|
type Scope a = IntMap a
|
||||||
|
|||||||
@ -24,11 +24,13 @@ import Hash.Constants
|
|||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
sbox :: Expr Var_ -> Compute (Expr Var_)
|
sbox :: Expr Var_ -> Compute (Expr Var_)
|
||||||
sbox x = do
|
sbox x0 = do
|
||||||
|
x <- let_ "x1_" x0
|
||||||
x2 <- let_ "x2_" (x *x )
|
x2 <- let_ "x2_" (x *x )
|
||||||
x3 <- let_ "x3_" (x *x2)
|
x3 <- let_ "x3_" (x *x2)
|
||||||
x4 <- let_ "x4_" (x2*x2)
|
x4 <- let_ "x4_" (x2*x2)
|
||||||
return (x3*x4)
|
x7 <- let_ "x7_" (x3*x4)
|
||||||
|
return x7
|
||||||
|
|
||||||
flipFoldM :: (Foldable t, Monad m) => b -> t a -> (b -> a -> m b) -> m b
|
flipFoldM :: (Foldable t, Monad m) => b -> t a -> (b -> a -> m b) -> m b
|
||||||
flipFoldM s0 list action = foldM action s0 list
|
flipFoldM s0 list action = foldM action s0 list
|
||||||
@ -70,15 +72,15 @@ poseidonGateConstraints = do
|
|||||||
return $ mds state4
|
return $ mds state4
|
||||||
|
|
||||||
-- partial rounds
|
-- partial rounds
|
||||||
let state' = zipWith (+) phase1 (map LitE $ elems fast_PARTIAL_FIRST_ROUND_CONSTANT)
|
state' <- lets_ $ zipWith (+) phase1 (map LitE $ elems fast_PARTIAL_FIRST_ROUND_CONSTANT)
|
||||||
let state'' = mdsInitPartial state'
|
state'' <- lets_ $ mdsInitPartial state'
|
||||||
phase2 <- flipFoldM state'' [0..21] $ \state1 r -> do
|
phase2 <- flipFoldM state'' [0..21] $ \state1 r -> do
|
||||||
let sbox_in = partial_sbox_in r
|
let sbox_in = partial_sbox_in r
|
||||||
commit $ (state1!!0) - sbox_in
|
commit $ (state1!!0) - sbox_in
|
||||||
y <- sbox sbox_in
|
y <- sbox sbox_in
|
||||||
let z = if r < 21 then y + LitE (fast_PARTIAL_ROUND_CONSTANTS!r) else y
|
let z = if r < 21 then y + LitE (fast_PARTIAL_ROUND_CONSTANTS!r) else y
|
||||||
let state2 = z : tail state1
|
state2 <- lets_ (z : tail state1)
|
||||||
return $ mdsFastPartial r state2
|
lets_ (mdsFastPartial r state2)
|
||||||
|
|
||||||
-- final full rounds
|
-- final full rounds
|
||||||
phase3 <- flipFoldM phase2 [0..3] $ \state1 r -> do
|
phase3 <- flipFoldM phase2 [0..3] $ \state1 r -> do
|
||||||
|
|||||||
@ -16,9 +16,12 @@ import Algebra.Goldilocks
|
|||||||
import Algebra.GoldilocksExt
|
import Algebra.GoldilocksExt
|
||||||
|
|
||||||
import Gate.Base
|
import Gate.Base
|
||||||
|
import Gate.Vars
|
||||||
import Gate.Computation
|
import Gate.Computation
|
||||||
import Gate.Constraints
|
import Gate.Constraints
|
||||||
|
|
||||||
|
import Misc.Pretty
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
-- | As exported into JSON by our Plonky2 fork
|
-- | As exported into JSON by our Plonky2 fork
|
||||||
@ -68,15 +71,26 @@ test_fibonacci = do
|
|||||||
let const_prg = gateProgram (ConstantGate 2)
|
let const_prg = gateProgram (ConstantGate 2)
|
||||||
|
|
||||||
let arith_evals = runStraightLine (fmap fromBase arith_row) arith_prg
|
let arith_evals = runStraightLine (fmap fromBase arith_row) arith_prg
|
||||||
let const_evals = runStraightLine (fmap fromBase const_row) const_prg
|
|
||||||
let posei_evals = runStraightLine (fmap fromBase posei_row) posei_prg
|
let posei_evals = runStraightLine (fmap fromBase posei_row) posei_prg
|
||||||
|
let const_evals = runStraightLine (fmap fromBase const_row) const_prg
|
||||||
|
|
||||||
putStrLn $ "number of constants in ArithmeticGate = " ++ show (length arith_evals)
|
putStrLn $ "number of constraints in ArithmeticGate = " ++ show (length arith_evals)
|
||||||
putStrLn $ "number of constants in ConstantGate = " ++ show (length const_evals)
|
putStrLn $ "number of constraints in PosiedonGate = " ++ show (length posei_evals)
|
||||||
putStrLn $ "number of constants in PosiedoGate = " ++ show (length posei_evals)
|
-- putStrLn $ "number of constraints in PublicInputGate = " ++ show (length pubio_evals)
|
||||||
|
putStrLn $ "number of constraints in ConstGate = " ++ show (length const_evals)
|
||||||
|
|
||||||
|
putStrLn $ "number of operations in ArithmeticGate = " ++ show (straightLineOperCount arith_prg)
|
||||||
|
putStrLn $ "number of operations in PosiedonGate = " ++ show (straightLineOperCount posei_prg)
|
||||||
|
putStrLn $ "number of operations in PublicInputGate = " ++ show (straightLineOperCount pubio_prg)
|
||||||
|
putStrLn $ "number of operations in ConstGate = " ++ show (straightLineOperCount const_prg)
|
||||||
|
|
||||||
print arith_evals
|
print arith_evals
|
||||||
print const_evals
|
|
||||||
print posei_evals
|
print posei_evals
|
||||||
|
print const_evals
|
||||||
|
|
||||||
|
-- printStraightLine posei_prg
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
main :: IO
|
||||||
|
main = test_fibonacci
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user