some refactoring (algebra)

This commit is contained in:
Balazs Komuves 2024-12-13 20:42:44 +01:00
parent 645d2024ed
commit bf9fb3e969
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
16 changed files with 171 additions and 132 deletions

138
src/Algebra/Expr.hs Normal file
View File

@ -0,0 +1,138 @@
-- | Polynomial expressions
--
{-# LANGUAGE StrictData, RecordWildCards #-}
module Algebra.Expr where
--------------------------------------------------------------------------------
import Data.Array
import Data.Char
import Text.Show
import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Gate.Base
--------------------------------------------------------------------------------
-- * Polynomial expressions
data Expr v
= VarE v -- ^ a variable
| LitE F -- ^ constant literal
| ScaleE F (Expr v) -- ^ linear scaling by a constant
| ImagE (Expr v) -- ^ multiplies by the field extension generator X
| SumE [Expr v] -- ^ sum of expressions
| ProdE [Expr v] -- ^ product of expressions
| PowE (Expr v) Int -- ^ exponentiation
deriving (Eq) -- ,Show)
instance Pretty var => Show (Expr var) where show = pretty
-- | Degree of the expression
exprDegree :: Expr var -> Int
exprDegree = go where
go expr = case expr of
VarE _ -> 1
LitE _ -> 0
ScaleE _ e -> go e
ImagE e -> go e
SumE es -> if null es then 0 else maximum (map go es)
ProdE es -> sum (map go es)
PowE e n -> n * go e
instance Num (Expr var) where
fromInteger = LitE . fromInteger
negate = negE
(+) = addE
(-) = subE
(*) = mulE
abs = error "Expr/abs"
signum = error "Expr/signum"
negE :: Expr var -> Expr var
negE (ScaleE s e) = ScaleE (negate s) e
negE e = ScaleE (-1) e
addE :: Expr var -> Expr var -> Expr var
addE (SumE es) (SumE fs) = SumE (es++fs )
addE e (SumE fs) = SumE (e : fs )
addE (SumE es) f = SumE (es++[f])
addE e f = SumE [e,f]
subE :: Expr var -> Expr var -> Expr var
subE e f = addE e (negate f)
sclE :: F -> Expr var -> Expr var
sclE s (ScaleE t e) = sclE (s*t) e
sclE s e = ScaleE s e
mulE :: Expr var -> Expr var -> Expr var
mulE (ScaleE s e) (ScaleE t f) = sclE (s*t) (mulE e f)
mulE (ScaleE s e) f = sclE s (mulE e f)
mulE (LitE s) f = sclE s f
mulE e (LitE t) = sclE t e
mulE e (ScaleE t f) = sclE t (mulE e f)
mulE (ProdE es) (ProdE fs) = ProdE (es++fs )
mulE e (ProdE fs) = ProdE (e : fs )
mulE (ProdE es) f = ProdE (es++[f])
mulE e f = ProdE [e,f]
--------------------------------------------------------------------------------
-- * pretty printing
-- | TODO: maybe move this somewhere else
class Pretty a where
prettyPrec :: Int -> a -> (String -> String)
pretty :: Pretty a => a -> String
pretty x = prettyPrec 0 x ""
instance Pretty F where prettyPrec _ x = shows x
instance Pretty FExt where prettyPrec _ x = shows x
instance Pretty var => Pretty (Expr var) where
prettyPrec d expr =
case expr of
VarE v -> prettyPrec 0 v
LitE x -> prettyPrec 0 x
ScaleE s e -> prettyPrec 0 s . showString " * " . showParen (d > mul_prec) (prettyPrec mul_prec e)
ImagE e -> showString "X*" . showParen (d > mul_prec) (prettyPrec mul_prec e)
SumE es -> showParen (d > add_prec) $ intercalates " + " $ map (prettyPrec add_prec) es
ProdE es -> showParen (d > mul_prec) $ intercalates " * " $ map (prettyPrec mul_prec) es
PowE e k -> showParen (d > pow_prec) $ (prettyPrec pow_prec e) . showString ("^" ++ show k)
where
add_prec = 5
mul_prec = 6
pow_prec = 7
intercalates sep = go where
go [] = id
go [x] = x
go (x:xs) = x . showString sep . go xs
--------------------------------------------------------------------------------
-- * Evaluation
{-
class (Eq a, Show a, Num a, Fractional a) => EvalField a where
fromGoldilocks :: Goldilocks -> a
instance EvalField F where fromGoldilocks = id
instance EvalField FExt where fromGoldilocks = fromBase
-}
evalExprWith :: (var -> FExt) -> Expr var -> FExt
evalExprWith evalVar expr = go expr where
go e = case e of
VarE v -> evalVar v
LitE x -> fromBase x
ScaleE s e -> fromBase s * go e
ImagE e -> (MkExt 0 1) * go e
SumE es -> sum (map go es)
ProdE es -> product (map go es)
PowE e n -> powExt (go e) (fromIntegral n)
--------------------------------------------------------------------------------

View File

@ -2,7 +2,7 @@
-- | Reference (simple, but very slow) implementation of the Goldilocks prime field
{-# LANGUAGE BangPatterns, NumericUnderscores #-}
module Goldilocks where
module Algebra.Goldilocks where
--------------------------------------------------------------------------------

View File

@ -6,7 +6,7 @@
-- (@X^2 - 7@ is the smallest such irreducible polynomial over Goldilocks)
--
module GoldilocksExt where
module Algebra.GoldilocksExt where
--------------------------------------------------------------------------------
@ -15,7 +15,7 @@ import Data.Ratio
import Data.Aeson ( ToJSON(..), FromJSON(..) )
import Goldilocks
import Algebra.Goldilocks
--------------------------------------------------------------------------------

View File

@ -10,8 +10,8 @@ import Control.Monad
import Data.Bits
import Goldilocks
import GoldilocksExt
import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Hash
import Digest
import Types

View File

@ -13,7 +13,7 @@ import Control.Monad.Identity
import qualified Control.Monad.State.Strict as S
import Control.Monad.IO.Class
import Goldilocks
import Algebra.Goldilocks
import Digest
import Challenge.Pure ( DuplexState, Squeeze, Absorb )
import qualified Challenge.Pure as Pure

View File

@ -15,8 +15,8 @@ module Challenge.Pure
import Data.Array
import Goldilocks
import GoldilocksExt
import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Poseidon
import Digest
import Types

View File

@ -6,8 +6,8 @@ module Challenge.Verifier where
--------------------------------------------------------------------------------
import Goldilocks
import GoldilocksExt
import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Hash
import Digest
import Types

View File

@ -8,7 +8,8 @@ module Constants where
--------------------------------------------------------------------------------
import Data.Array
import Goldilocks
import Algebra.Goldilocks
--------------------------------------------------------------------------------

View File

@ -11,7 +11,7 @@ import Data.Bits
import GHC.Generics
import Data.Aeson ( FromJSON(..) , ToJSON(..) , object , withObject , (.=) , (.:) )
import Goldilocks
import Algebra.Goldilocks
--------------------------------------------------------------------------------

View File

@ -11,7 +11,7 @@ import Data.Word
import Data.Aeson
import GHC.Generics
import Goldilocks
import Algebra.Goldilocks
--------------------------------------------------------------------------------

View File

@ -15,8 +15,9 @@ import Data.Char
import Text.Show
import Goldilocks
import GoldilocksExt
import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Algebra.Expr
import Gate.Base
@ -31,103 +32,12 @@ data Var
| PIV Int -- ^ public input hash variable
deriving (Eq,Ord,Show)
data Expr
= VarE Var -- ^ a variable
| LitE F -- ^ constant literal
| ScaleE F Expr -- ^ linear scaling by a constant
| ImagE Expr -- ^ multiplies by the field extension generator X
| SumE [Expr] -- ^ sum of expressions
| ProdE [Expr] -- ^ product of expressions
| PowE Expr Int -- ^ exponentiation
deriving (Eq) -- ,Show)
instance Show Expr where show = pretty
-- | Degree of the expression
exprDegree :: Expr -> Int
exprDegree = go where
go expr = case expr of
VarE _ -> 1
LitE _ -> 0
ScaleE _ e -> go e
ImagE e -> go e
SumE es -> if null es then 0 else maximum (map go es)
ProdE es -> sum (map go es)
PowE e n -> n * go e
instance Num Expr where
fromInteger = LitE . fromInteger
negate = negE
(+) = addE
(-) = subE
(*) = mulE
abs = error "Expr/abs"
signum = error "Expr/signum"
negE :: Expr -> Expr
negE (ScaleE s e) = ScaleE (negate s) e
negE e = ScaleE (-1) e
addE :: Expr -> Expr -> Expr
addE (SumE es) (SumE fs) = SumE (es++fs )
addE e (SumE fs) = SumE (e : fs )
addE (SumE es) f = SumE (es++[f])
addE e f = SumE [e,f]
subE :: Expr -> Expr -> Expr
subE e f = addE e (negate f)
sclE :: F -> Expr -> Expr
sclE s (ScaleE t e) = sclE (s*t) e
sclE s e = ScaleE s e
mulE :: Expr -> Expr -> Expr
mulE (ScaleE s e) (ScaleE t f) = sclE (s*t) (mulE e f)
mulE (ScaleE s e) f = sclE s (mulE e f)
mulE (LitE s) f = sclE s f
mulE e (LitE t) = sclE t e
mulE e (ScaleE t f) = sclE t (mulE e f)
mulE (ProdE es) (ProdE fs) = ProdE (es++fs )
mulE e (ProdE fs) = ProdE (e : fs )
mulE (ProdE es) f = ProdE (es++[f])
mulE e f = ProdE [e,f]
--------------------------------------------------------------------------------
-- * pretty printing
class Pretty a where
prettyPrec :: Int -> a -> (String -> String)
pretty :: Pretty a => a -> String
pretty x = prettyPrec 0 x ""
instance Pretty F where prettyPrec _ x = shows x
instance Pretty FExt where prettyPrec _ x = shows x
instance Pretty Var where
prettyPrec _ v = case v of
SelV k -> showString ("s" ++ show k)
ConstV k -> showString ("c" ++ show k)
WireV k -> showString ("w" ++ show k)
PIV k -> showString ("h" ++ show k)
instance Pretty Expr where
prettyPrec d expr =
case expr of
VarE v -> prettyPrec 0 v
LitE x -> prettyPrec 0 x
ScaleE s e -> prettyPrec 0 s . showString " * " . showParen (d > mul_prec) (prettyPrec mul_prec e)
ImagE e -> showString "X*" . showParen (d > mul_prec) (prettyPrec mul_prec e)
SumE es -> showParen (d > add_prec) $ intercalates " + " $ map (prettyPrec add_prec) es
ProdE es -> showParen (d > mul_prec) $ intercalates " * " $ map (prettyPrec mul_prec) es
PowE e k -> showParen (d > pow_prec) $ (prettyPrec pow_prec e) . showString ("^" ++ show k)
where
add_prec = 5
mul_prec = 6
pow_prec = 7
intercalates sep = go where
go [] = id
go [x] = x
go (x:xs) = x . showString sep . go xs
--------------------------------------------------------------------------------
@ -145,29 +55,19 @@ data EvaluationVars a = MkEvaluationVars
--------------------------------------------------------------------------------
-- * Evaluation
class (Eq a, Show a, Num a, Fractional a) => EvalField a where
fromGoldilocks :: Goldilocks -> a
instance EvalField F where fromGoldilocks = id
instance EvalField FExt where fromGoldilocks = fromBase
evalExpr :: EvalField a => Expr -> EvaluationVars a -> a
evalExpr expr (MkEvaluationVars{..}) = go expr where
go e = case e of
VarE v -> case v of
SelV k -> local_selectors ! k
ConstV k -> local_constants ! k
WireV k -> local_wires ! k
PIV k -> fromGoldilocks (public_inputs_hash !! k)
LitE x -> fromGoldilocks x
SumE es -> sum (map go es)
ProdE es -> product (map go es)
evalExpr :: Expr Var -> EvaluationVars FExt -> FExt
evalExpr expr (MkEvaluationVars{..}) = evalExprWith f expr where
f v = case v of
SelV k -> local_selectors ! k
ConstV k -> local_constants ! k
WireV k -> local_wires ! k
PIV k -> fromBase (public_inputs_hash !! k)
--------------------------------------------------------------------------------
-- * Gate constraints
-- | Returns the (symbolic) constraints for the given gate
gateConstraints :: Gate -> [Expr]
gateConstraints :: Gate -> [Expr Var]
gateConstraints gate =
case gate of

View File

@ -18,7 +18,7 @@ import GHC.Generics
import "parsec1" Text.ParserCombinators.Parsec
import Goldilocks
import Algebra.Goldilocks
import Gate.Base
--------------------------------------------------------------------------------

View File

@ -7,7 +7,7 @@ import Data.Array
import Data.Word
import Data.Bits
import Goldilocks
import Algebra.Goldilocks
import Poseidon
import Digest

View File

@ -14,7 +14,7 @@ import Data.Word
import Data.Array (Array)
import Data.Array.IArray
import Goldilocks
import Algebra.Goldilocks
import Constants
import Digest

View File

@ -15,8 +15,8 @@ import qualified Data.ByteString.Lazy.Char8 as L
import GHC.Generics
import Goldilocks
import GoldilocksExt
import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Digest
import Gate.Base
import Gate.Parser

View File

@ -8,7 +8,7 @@ import Data.Aeson
import Types
import Hash
import Digest
import Goldilocks
import Algebra.Goldilocks
import Challenge.Verifier
import qualified Data.ByteString.Char8 as B
@ -28,7 +28,7 @@ main = do
let Just common_data = decode text_common :: Maybe CommonCircuitData
let Just verifier_data = decode text_vkey :: Maybe VerifierOnlyCircuitData
let Just proof_data = decode text_proof :: Maybe ProofWithPublicInputs
let challenges = proofChallenges common_data verifier_data proof_data
print challenges