diff --git a/src/Algebra/Expr.hs b/src/Algebra/Expr.hs index cc4acb0..82f89dc 100644 --- a/src/Algebra/Expr.hs +++ b/src/Algebra/Expr.hs @@ -11,6 +11,8 @@ import Prelude hiding ( (^) ) import Data.Array import Data.Char +import Data.Monoid +import Data.Semigroup import Text.Show @@ -35,10 +37,10 @@ data Expr v -- instance Pretty var => Show (Expr var) where show = pretty -- | Degree of the expression -exprDegree :: Expr var -> Int -exprDegree = go where +exprDegree :: (var -> Int) -> Expr var -> Int +exprDegree varDeg = go where go expr = case expr of - VarE _ -> 1 + VarE v -> varDeg v LitE _ -> 0 AddE e1 e2 -> max (go e1) (go e2) SubE e1 e2 -> max (go e1) (go e2) @@ -68,6 +70,31 @@ mulE :: Expr var -> Expr var -> Expr var mulE = MulE -} +-------------------------------------------------------------------------------- +-- * Operation counting + +data OperCount = MkOperCount + { numberOfAdds :: Int + , numberOfMuls :: Int + } + deriving (Eq,Show) + +instance Semigroup OperCount where + (<>) (MkOperCount a1 m1) (MkOperCount a2 m2) = MkOperCount (a1+a2) (m1+m2) + +instance Monoid OperCount where + mempty = MkOperCount 0 0 + +exprOperCount :: Expr var -> OperCount +exprOperCount = go where + go expr = case expr of + VarE _ -> mempty + LitE e -> mempty + AddE e1 e2 -> go e1 <> go e2 <> MkOperCount 1 0 + SubE e1 e2 -> go e1 <> go e2 <> MkOperCount 1 0 + MulE e1 e2 -> go e1 <> go e2 <> MkOperCount 0 1 + ImgE e -> go e <> MkOperCount 0 1 + -------------------------------------------------------------------------------- -- * pretty printing diff --git a/src/Gate/Computation.hs b/src/Gate/Computation.hs index c275bb5..a5cd381 100644 --- a/src/Gate/Computation.hs +++ b/src/Gate/Computation.hs @@ -61,7 +61,13 @@ type Def_ = LocalDef Var_ type Compute a = Program (Instr Var_) a let_ :: String -> Expr Var_ -> Compute (Expr Var_) -let_ name rhs = VarE <$> Instr (Let name rhs) +let_ name rhs = case rhs of + VarE _ -> return rhs + LitE _ -> return rhs + _ -> VarE <$> Instr (Let name rhs) + +lets_ :: [Expr Var_] -> Compute [Expr Var_] +lets_ = mapM (let_ "") commit :: Expr Var_ -> Compute () commit what = Instr (Commit what) @@ -72,8 +78,11 @@ commitList = mapM_ commit -------------------------------------------------------------------------------- -- | Straightline programs -data LocalDef v - = MkLocalDef Int String (Expr v) +data LocalDef v = MkLocalDef + { localDefVarIdx :: Int + , localDefVarName :: String + , localDefRHS :: Expr v + } deriving (Eq,Show) instance Pretty v => Pretty (LocalDef v) where @@ -109,6 +118,12 @@ compileToStraightLine = fst . go emptyStraightLine where state' = MkStraightLine (def:localdefs) commits (counter+1) in (state', LocalVar counter name) +straightLineOperCount :: StraightLine -> OperCount +straightLineOperCount (MkStraightLine{..}) = final where + defs = map exprOperCount $ map localDefRHS localdefs + coms = map exprOperCount $ commits + final = mconcat defs <> mconcat coms + -------------------------------------------------------------------------------- type Scope a = IntMap a diff --git a/src/Gate/Poseidon.hs b/src/Gate/Poseidon.hs index 20c9fd4..9bb60c4 100644 --- a/src/Gate/Poseidon.hs +++ b/src/Gate/Poseidon.hs @@ -24,11 +24,13 @@ import Hash.Constants -------------------------------------------------------------------------------- sbox :: Expr Var_ -> Compute (Expr Var_) -sbox x = do +sbox x0 = do + x <- let_ "x1_" x0 x2 <- let_ "x2_" (x *x ) x3 <- let_ "x3_" (x *x2) x4 <- let_ "x4_" (x2*x2) - return (x3*x4) + x7 <- let_ "x7_" (x3*x4) + return x7 flipFoldM :: (Foldable t, Monad m) => b -> t a -> (b -> a -> m b) -> m b flipFoldM s0 list action = foldM action s0 list @@ -70,15 +72,15 @@ poseidonGateConstraints = do return $ mds state4 -- partial rounds - let state' = zipWith (+) phase1 (map LitE $ elems fast_PARTIAL_FIRST_ROUND_CONSTANT) - let state'' = mdsInitPartial state' + state' <- lets_ $ zipWith (+) phase1 (map LitE $ elems fast_PARTIAL_FIRST_ROUND_CONSTANT) + state'' <- lets_ $ mdsInitPartial state' phase2 <- flipFoldM state'' [0..21] $ \state1 r -> do let sbox_in = partial_sbox_in r commit $ (state1!!0) - sbox_in y <- sbox sbox_in let z = if r < 21 then y + LitE (fast_PARTIAL_ROUND_CONSTANTS!r) else y - let state2 = z : tail state1 - return $ mdsFastPartial r state2 + state2 <- lets_ (z : tail state1) + lets_ (mdsFastPartial r state2) -- final full rounds phase3 <- flipFoldM phase2 [0..3] $ \state1 r -> do diff --git a/src/Test/Witness.hs b/src/Test/Witness.hs index 86e4ab0..19d806f 100644 --- a/src/Test/Witness.hs +++ b/src/Test/Witness.hs @@ -16,9 +16,12 @@ import Algebra.Goldilocks import Algebra.GoldilocksExt import Gate.Base +import Gate.Vars import Gate.Computation import Gate.Constraints +import Misc.Pretty + -------------------------------------------------------------------------------- -- | As exported into JSON by our Plonky2 fork @@ -68,15 +71,26 @@ test_fibonacci = do let const_prg = gateProgram (ConstantGate 2) let arith_evals = runStraightLine (fmap fromBase arith_row) arith_prg - let const_evals = runStraightLine (fmap fromBase const_row) const_prg let posei_evals = runStraightLine (fmap fromBase posei_row) posei_prg + let const_evals = runStraightLine (fmap fromBase const_row) const_prg - putStrLn $ "number of constants in ArithmeticGate = " ++ show (length arith_evals) - putStrLn $ "number of constants in ConstantGate = " ++ show (length const_evals) - putStrLn $ "number of constants in PosiedoGate = " ++ show (length posei_evals) + putStrLn $ "number of constraints in ArithmeticGate = " ++ show (length arith_evals) + putStrLn $ "number of constraints in PosiedonGate = " ++ show (length posei_evals) + -- putStrLn $ "number of constraints in PublicInputGate = " ++ show (length pubio_evals) + putStrLn $ "number of constraints in ConstGate = " ++ show (length const_evals) + + putStrLn $ "number of operations in ArithmeticGate = " ++ show (straightLineOperCount arith_prg) + putStrLn $ "number of operations in PosiedonGate = " ++ show (straightLineOperCount posei_prg) + putStrLn $ "number of operations in PublicInputGate = " ++ show (straightLineOperCount pubio_prg) + putStrLn $ "number of operations in ConstGate = " ++ show (straightLineOperCount const_prg) print arith_evals - print const_evals print posei_evals + print const_evals + + -- printStraightLine posei_prg -------------------------------------------------------------------------------- + +main :: IO +main = test_fibonacci