diff --git a/README.md b/README.md index a22426e..2524db7 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ A standalone Plonky2 verifier ----------------------------- -This is a (WIP) implementation of a Plonky2 verifier in Haskell. +This is a (WIP) implementation of a Plonky2 verifier written in Haskell. [Plonky2](https://github.com/0xPolygonZero/plonky2/) is a zero-knowledge proof system developed by Polygon Zero, optimized for recursive proofs. @@ -15,7 +15,7 @@ Another goal is to be a basis for further tooling (for example: estimating verifier costs, helping the design of recursive circuits, generating Plonky2 verifier circuits for other proof systems, etc) -Note: It's deliberately no a goal for this verifier to be efficient; instead we +Note: It's deliberately not a goal for this verifier to be efficient; instead we try to focus on simplicity. @@ -42,7 +42,7 @@ Supported gates: - [x] MulExtensionGate - [x] NoopGate - [x] PublicInputGate -- [ ] PoseidonGate +- [x] PoseidonGate - [ ] PoseidonMdsGate - [ ] RandomAccessGate - [ ] ReducingGate @@ -52,6 +52,6 @@ Optional features: - [ ] Field extensions with degree higher than 2 - [ ] Being parametric over the field choice -- [ ] Supporting different hash function choices +- [ ] Supporting different hash functions diff --git a/src/Algebra/Expr.hs b/src/Algebra/Expr.hs index 649ec51..5f09e8f 100644 --- a/src/Algebra/Expr.hs +++ b/src/Algebra/Expr.hs @@ -7,6 +7,8 @@ module Algebra.Expr where -------------------------------------------------------------------------------- +import Prelude hiding ( (^) ) + import Data.Array import Data.Char @@ -16,102 +18,76 @@ import Algebra.Goldilocks import Algebra.GoldilocksExt import Gate.Base +import Misc.Pretty -------------------------------------------------------------------------------- -- * Polynomial expressions data Expr v - = VarE v -- ^ a variable - | LitE F -- ^ constant literal - | ScaleE F (Expr v) -- ^ linear scaling by a constant - | ImagE (Expr v) -- ^ multiplies by the field extension generator X - | SumE [Expr v] -- ^ sum of expressions - | ProdE [Expr v] -- ^ product of expressions - | PowE (Expr v) Int -- ^ exponentiation - deriving (Eq) -- ,Show) + = VarE v -- ^ a variable + | LitE F -- ^ constant literal + | AddE (Expr v) (Expr v) -- ^ addition + | SubE (Expr v) (Expr v) -- ^ subtraction + | MulE (Expr v) (Expr v) -- ^ multiplication + | ImgE (Expr v) -- ^ multiplies by the field extension generator X + deriving (Eq,Show) -instance Pretty var => Show (Expr var) where show = pretty +-- instance Pretty var => Show (Expr var) where show = pretty -- | Degree of the expression exprDegree :: Expr var -> Int exprDegree = go where go expr = case expr of - VarE _ -> 1 - LitE _ -> 0 - ScaleE _ e -> go e - ImagE e -> go e - SumE es -> if null es then 0 else maximum (map go es) - ProdE es -> sum (map go es) - PowE e n -> n * go e + VarE _ -> 1 + LitE _ -> 0 + AddE e1 e2 -> max (go e1) (go e2) + SubE e1 e2 -> max (go e1) (go e2) + MulE e1 e2 -> go e1 + go e2 + ImgE e -> go e instance Num (Expr var) where fromInteger = LitE . fromInteger negate = negE - (+) = addE - (-) = subE - (*) = mulE + (+) = AddE + (-) = SubE + (*) = MulE abs = error "Expr/abs" signum = error "Expr/signum" negE :: Expr var -> Expr var -negE (ScaleE s e) = ScaleE (negate s) e -negE e = ScaleE (-1) e +negE e = SubE (LitE 0) e + +{- +(^) :: Expr var -> Int -> Expr var +(^) = PowE addE :: Expr var -> Expr var -> Expr var -addE (SumE es) (SumE fs) = SumE (es++fs ) -addE e (SumE fs) = SumE (e : fs ) -addE (SumE es) f = SumE (es++[f]) -addE e f = SumE [e,f] - -subE :: Expr var -> Expr var -> Expr var -subE e f = addE e (negate f) - -sclE :: F -> Expr var -> Expr var -sclE s (ScaleE t e) = sclE (s*t) e -sclE s e = ScaleE s e +addE = AddE mulE :: Expr var -> Expr var -> Expr var -mulE (ScaleE s e) (ScaleE t f) = sclE (s*t) (mulE e f) -mulE (ScaleE s e) f = sclE s (mulE e f) -mulE (LitE s) f = sclE s f -mulE e (LitE t) = sclE t e -mulE e (ScaleE t f) = sclE t (mulE e f) -mulE (ProdE es) (ProdE fs) = ProdE (es++fs ) -mulE e (ProdE fs) = ProdE (e : fs ) -mulE (ProdE es) f = ProdE (es++[f]) -mulE e f = ProdE [e,f] +mulE = MulE +-} -------------------------------------------------------------------------------- -- * pretty printing --- | TODO: maybe move this somewhere else -class Pretty a where - prettyPrec :: Int -> a -> (String -> String) - -pretty :: Pretty a => a -> String -pretty x = prettyPrec 0 x "" - -instance Pretty F where prettyPrec _ x = shows x -instance Pretty FExt where prettyPrec _ x = shows x - instance Pretty var => Pretty (Expr var) where prettyPrec d expr = case expr of - VarE v -> prettyPrec 0 v - LitE x -> prettyPrec 0 x - ScaleE s e -> prettyPrec 0 s . showString " * " . showParen (d > mul_prec) (prettyPrec mul_prec e) - ImagE e -> showString "X*" . showParen (d > mul_prec) (prettyPrec mul_prec e) - SumE es -> showParen (d > add_prec) $ intercalates " + " $ map (prettyPrec add_prec) es - ProdE es -> showParen (d > mul_prec) $ intercalates " * " $ map (prettyPrec mul_prec) es - PowE e k -> showParen (d > pow_prec) $ (prettyPrec pow_prec e) . showString ("^" ++ show k) + VarE v -> prettyPrec 0 v + LitE x -> prettyPrec 0 x + AddE e1 e2 -> showParen (d > add_prec) $ prettyPrec add_prec e1 . showString " + " . prettyPrec (add_prec+1) e2 + SubE e1 e2 -> showParen (d > add_prec) $ prettyPrec add_prec e1 . showString " - " . prettyPrec (add_prec+1) e2 + MulE e1 e2 -> showParen (d > mul_prec) $ prettyPrec add_prec e1 . showString " * " . prettyPrec (mul_prec+1) e2 + ImgE e -> showParen (d > mul_prec) $ showString "X*" . (prettyPrec mul_prec e) where add_prec = 5 mul_prec = 6 - pow_prec = 7 - intercalates sep = go where - go [] = id - go [x] = x - go (x:xs) = x . showString sep . go xs + -- pow_prec = 7 + -- intercalates sep = go where + -- go [] = id + -- go [x] = x + -- go (x:xs) = x . showString sep . go xs -------------------------------------------------------------------------------- -- * Evaluation @@ -127,12 +103,11 @@ instance EvalField FExt where fromGoldilocks = fromBase evalExprWith :: (var -> FExt) -> Expr var -> FExt evalExprWith evalVar expr = go expr where go e = case e of - VarE v -> evalVar v - LitE x -> fromBase x - ScaleE s e -> fromBase s * go e - ImagE e -> (MkExt 0 1) * go e - SumE es -> sum (map go es) - ProdE es -> product (map go es) - PowE e n -> powExt (go e) (fromIntegral n) + VarE v -> evalVar v + LitE x -> fromBase x + AddE e1 e2 -> go e1 + go e2 + SubE e1 e2 -> go e1 - go e2 + MulE e1 e2 -> go e1 * go e2 + ImgE e -> (MkExt 0 1) * go e -------------------------------------------------------------------------------- diff --git a/src/Algebra/Goldilocks.hs b/src/Algebra/Goldilocks.hs index 2290b1c..e3e88ad 100644 --- a/src/Algebra/Goldilocks.hs +++ b/src/Algebra/Goldilocks.hs @@ -14,6 +14,7 @@ import Data.Word import Data.Ratio import Data.Array +import Text.Show import Text.Printf import System.Random @@ -21,6 +22,8 @@ import System.Random import GHC.Generics import Data.Aeson ( ToJSON(..), FromJSON(..) ) +import Misc.Pretty + -------------------------------------------------------------------------------- type F = Goldilocks @@ -78,6 +81,8 @@ instance Show Goldilocks where show (Goldilocks x) = show x -- decimal -- show (Goldilocks x) = printf "0x%016x" x -- hex +instance Pretty Goldilocks where prettyPrec _ x = shows x + -------------------------------------------------------------------------------- instance ToJSON Goldilocks where diff --git a/src/Algebra/GoldilocksExt.hs b/src/Algebra/GoldilocksExt.hs index 2809287..1dee91c 100644 --- a/src/Algebra/GoldilocksExt.hs +++ b/src/Algebra/GoldilocksExt.hs @@ -12,10 +12,12 @@ module Algebra.GoldilocksExt where import Data.Bits import Data.Ratio +import Text.Show import Data.Aeson ( ToJSON(..), FromJSON(..) ) import Algebra.Goldilocks +import Misc.Pretty -------------------------------------------------------------------------------- @@ -31,6 +33,12 @@ fromBase x = MkExt x 0 instance Show GoldilocksExt where show (MkExt real imag) = "(" ++ show real ++ " + X*" ++ show imag ++ ")" +instance Pretty GoldilocksExt where + prettyPrec d (MkExt real imag) + | imag == 0 = prettyPrec 0 real + | otherwise = showParen (d > 5) + $ prettyPrec 0 real . showString " + X*" . prettyPrec 0 imag + instance ToJSON GoldilocksExt where toJSON (MkExt a b) = toJSON (a,b) @@ -81,3 +89,11 @@ powExt x e _ -> go (acc*s) (sqrExt s) (shiftR expo 1) -------------------------------------------------------------------------------- + +rndExt :: IO FExt +rndExt = do + x <- rndF + y <- rndF + return (MkExt x y) + +-------------------------------------------------------------------------------- diff --git a/src/Gate/Computation.hs b/src/Gate/Computation.hs new file mode 100644 index 0000000..c275bb5 --- /dev/null +++ b/src/Gate/Computation.hs @@ -0,0 +1,160 @@ + +-- | We have to describe the contraints as computations with local definitions +-- Without local definitions the equations would just blow up +-- in the case of the Poseidon gate for example +-- + +{-# LANGUAGE StrictData, DeriveFunctor, GADTs, RecordWildCards #-} +module Gate.Computation where + +-------------------------------------------------------------------------------- + +import Prelude hiding ( (^) ) + +import Control.Applicative +import Control.Monad + +import Data.Array +import Data.List +import Text.Show + +import Data.IntMap (IntMap) +import qualified Data.IntMap as IntMap + +import Algebra.Goldilocks +import Algebra.GoldilocksExt +import Algebra.Expr + +import Gate.Vars +import Misc.Pretty + +-------------------------------------------------------------------------------- +-- * Operational monad + +data Instr var a where + Let :: String -> Expr var -> Instr var var + Commit :: Expr var -> Instr var () + +data Program instr a where + Bind :: Program instr a -> (a -> Program instr b) -> Program instr b + Return :: a -> Program instr a + Instr :: instr a -> Program instr a + +instance Functor (Program instr) where + fmap = liftM + +instance Applicative (Program instr) where + (<*>) = ap + pure = Return + +instance Monad (Program instr) where + (>>=) = Bind + +type Var_ = Var PlonkyVar +type Expr_ = Expr Var_ +type Def_ = LocalDef Var_ +--type Instr_ a = Instr Var_ a + +-------------------------------------------------------------------------------- + +-- | Our computation monad +type Compute a = Program (Instr Var_) a + +let_ :: String -> Expr Var_ -> Compute (Expr Var_) +let_ name rhs = VarE <$> Instr (Let name rhs) + +commit :: Expr Var_ -> Compute () +commit what = Instr (Commit what) + +commitList :: [Expr Var_] -> Compute () +commitList = mapM_ commit + +-------------------------------------------------------------------------------- +-- | Straightline programs + +data LocalDef v + = MkLocalDef Int String (Expr v) + deriving (Eq,Show) + +instance Pretty v => Pretty (LocalDef v) where + prettyPrec _ (MkLocalDef k name rhs) = showString ("_" ++ name ++ show k) . showString " := " . prettyPrec 0 rhs + +-- | A straightline program encoding the computation of constraints +data StraightLine = MkStraightLine + { localdefs :: [LocalDef Var_] -- ^ local definitions, in reverse order + , commits :: [Expr_] -- ^ committed constraints, in reverse order + , counter :: Int -- ^ fresh variable counter + } + deriving Show + +emptyStraightLine :: StraightLine +emptyStraightLine = MkStraightLine [] [] 0 + +printStraightLine :: StraightLine -> IO () +printStraightLine (MkStraightLine{..}) = do + forM_ (reverse localdefs) $ \def -> putStrLn (pretty def) + forM_ (reverse commits ) $ \expr -> putStrLn $ "constraint 0 == " ++ (pretty expr) + +compileToStraightLine :: Compute () -> StraightLine +compileToStraightLine = fst . go emptyStraightLine where + go :: StraightLine -> Compute a -> (StraightLine,a) + go state instr = case instr of + Return x -> (state,x) + Bind this rest -> let (state',x) = go state this in go state' (rest x) + Instr this -> case state of + MkStraightLine{..} -> case this of + Commit what -> let state' = MkStraightLine localdefs (what:commits) counter + in (state', ()) + Let name rhs -> let def = MkLocalDef counter name rhs + state' = MkStraightLine (def:localdefs) commits (counter+1) + in (state', LocalVar counter name) + +-------------------------------------------------------------------------------- + +type Scope a = IntMap a + +emptyScope :: Scope a +emptyScope = IntMap.empty + +-- | Run a \"straightline program\", resulting in list of contraints evaluations +runStraightLine :: EvaluationVars FExt -> StraightLine -> [FExt] +runStraightLine = runStraightLine' emptyScope + +runStraightLine' :: Scope FExt -> EvaluationVars FExt -> StraightLine -> [FExt] +runStraightLine' iniScope vars (MkStraightLine{..}) = result where + finalScope = foldl' worker iniScope (reverse localdefs) + result = evalConstraints finalScope vars (reverse commits) + worker !scope (MkLocalDef i _ rhs) = IntMap.insert i (evalConstraint scope vars rhs) scope + +-------------------------------------------------------------------------------- +-- * Evaluation + +type Constraint = Expr_ + +-- | List of all data (one "row") we need to evaluate a gate constraint +-- +-- Typically this will be the evaluations of the column polynomials at @zeta@ +data EvaluationVars a = MkEvaluationVars + { local_selectors :: Array Int a -- ^ the selectors + , local_constants :: Array Int a -- ^ the circuit constants + , local_wires :: Array Int a -- ^ the advice wires (witness) + , public_inputs_hash :: [F] -- ^ only used in @PublicInputGate@ + } + deriving (Show,Functor) + +evalConstraint :: Scope FExt -> EvaluationVars FExt -> Constraint -> FExt +evalConstraint scope (MkEvaluationVars{..}) expr = evalExprWith f expr where + f var = case var of + LocalVar i n -> case IntMap.lookup i scope of + Just y -> y + Nothing -> error $ "variable _" ++ n ++ show i ++ " not in scope" + ProofVar v -> case v of + SelV k -> local_selectors ! k + ConstV k -> local_constants ! k + WireV k -> local_wires ! k + PIV k -> fromBase (public_inputs_hash !! k) + +evalConstraints :: Scope FExt -> EvaluationVars FExt -> [Constraint] -> [FExt] +evalConstraints scope vars = map (evalConstraint scope vars) + +-------------------------------------------------------------------------------- diff --git a/src/Gate/Constraints.hs b/src/Gate/Constraints.hs index 12d701b..ed593ef 100644 --- a/src/Gate/Constraints.hs +++ b/src/Gate/Constraints.hs @@ -5,123 +5,92 @@ -- of constraints. -- -{-# LANGUAGE StrictData, RecordWildCards #-} +{-# LANGUAGE StrictData, DeriveFunctor, GADTs, RecordWildCards #-} module Gate.Constraints where -------------------------------------------------------------------------------- -import Data.Array -import Data.Char +import Prelude hiding ( (^) ) +import Data.Array hiding (range) +import Data.Char import Text.Show +import Data.IntMap (IntMap) +import qualified Data.IntMap as IntMap + import Algebra.Goldilocks import Algebra.GoldilocksExt import Algebra.Expr import Gate.Base +import Gate.Vars +import Gate.Computation +import Gate.Poseidon --------------------------------------------------------------------------------- --- * Constraint expressions - --- | These index into a row + public input -data Var - = SelV Int -- ^ selector variable - | ConstV Int -- ^ constant variable - | WireV Int -- ^ wire variable - | PIV Int -- ^ public input hash variable - deriving (Eq,Ord,Show) - -instance Pretty Var where - prettyPrec _ v = case v of - SelV k -> showString ("s" ++ show k) - ConstV k -> showString ("c" ++ show k) - WireV k -> showString ("w" ++ show k) - PIV k -> showString ("h" ++ show k) - --------------------------------------------------------------------------------- - --- | List of all data (one "row") we need to evaluate a gate constraint --- --- Typically this will be the evaluations of the column polynomials at @zeta@ -data EvaluationVars a = MkEvaluationVars - { local_selectors :: Array Int a -- ^ the selectors - , local_constants :: Array Int a -- ^ the circuit constants - , local_wires :: Array Int a -- ^ the advice wires (witness) - , public_inputs_hash :: [F] -- ^ only used in @PublicInputGate@ - } - deriving (Show) - --------------------------------------------------------------------------------- --- * Evaluation - -evalExpr :: Expr Var -> EvaluationVars FExt -> FExt -evalExpr expr (MkEvaluationVars{..}) = evalExprWith f expr where - f v = case v of - SelV k -> local_selectors ! k - ConstV k -> local_constants ! k - WireV k -> local_wires ! k - PIV k -> fromBase (public_inputs_hash !! k) - -------------------------------------------------------------------------------- -- * Gate constraints -- | Returns the (symbolic) constraints for the given gate -gateConstraints :: Gate -> [Expr Var] -gateConstraints gate = +-- +gateProgram :: Gate -> StraightLine +gateProgram = compileToStraightLine . gateComputation + +gateComputation :: Gate -> Compute () +gateComputation gate = case gate of -- `w[i] - c0*x[i]*y[i] - c1*z[i] = 0` ArithmeticGate num_ops - -> [ ww (j+3) - cc 0 * ww j * ww (j+1) - cc 1 * ww (j+2) | i<-range num_ops, let j = 4*i ] + -> commitList [ wire (j+3) - cnst 0 * wire j * wire (j+1) - cnst 1 * wire (j+2) | i<-range num_ops, let j = 4*i ] -- same but consecutive witness variables make up an extension field element ArithmeticExtensionGate num_ops - -> [ wwExt (j+6) - cc 0 * wwExt j * wwExt (j+2) - cc 1 * wwExt (j+4) | i<-range num_ops, let j = 8*i ] + -> commitList [ wireExt (j+6) - cnst 0 * wireExt j * wireExt (j+2) - cnst 1 * wireExt (j+4) | i<-range num_ops, let j = 8*i ] -- `sum b^i * limbs[i] - out = 0`, and `0 <= limb[i] < B` is enforced BaseSumGate num_limbs base - -> let limb i = ww (i+1) + -> let limb i = wire (i+1) horner = go 0 where go k = if k < num_limbs-1 then limb k + fromIntegral base * go (k+1) else limb k - sum_eq = horner - ww 0 - range_eq i = ProdE [ limb i - fromIntegral k | k<-[0..base-1] ] - in sum_eq : [ range_eq i | i<-range num_limbs ] + sum_eq = horner - wire 0 + range_eq i = product [ limb i - fromIntegral k | k<-[0..base-1] ] + in commitList $ sum_eq : [ range_eq i | i<-range num_limbs ] CosetInterpolationGate subgroup_bits coset_degree barycentric_weights -> todo -- `c[i] - x[i] = 0` ConstantGate num_consts - -> [ cc i - ww i | i <- range num_consts ] + -> commitList [ cnst i - wire i | i <- range num_consts ] -- computes `out = base ^ (sum 2^i e_i)` -- order of witness variables: [ base, e[0],...,e[n-1], output, t[0]...t[n-1] ] ExponentiationGate num_power_bits - -> let base = ww 0 - exp_bit i = ww (i+1) - out = ww (num_power_bits+1) + -> let base = wire 0 + exp_bit i = wire (i+1) + out = wire (num_power_bits+1) tmp_val 0 = 1 - tmp_val i = ww (num_power_bits+1+i) + tmp_val i = wire (num_power_bits+1+i) cur_bit i = exp_bit (num_power_bits - 1 - i) eq i = tmp_val (i-1) * (cur_bit i * base + 1 - cur_bit i) - tmp_val i - in [ eq i | i <- range num_power_bits ] ++ [ out - tmp_val (num_power_bits-1) ] + in commitList $ [ eq i | i <- range num_power_bits ] ++ [ out - tmp_val (num_power_bits-1) ] -- lookups are handled specially, no constraints here - LookupGate num_slots lut_hash -> [] - LookupTableGate num_slots lut_hash last_lut_row -> [] + LookupGate num_slots lut_hash -> return () + LookupTableGate num_slots lut_hash last_lut_row -> return () -- `z[i] - c0*x[i]*y[i] = 0`, and two witness cells make up an extension field element MulExtensionGate num_ops - -> [ wwExt (j+4) - cc 0 * wwExt j * wwExt (j+2) | i<-range num_ops, let j = 6*i ] + -> commitList [ wireExt (j+4) - cnst 0 * wireExt j * wireExt (j+2) | i<-range num_ops, let j = 6*i ] - NoopGate -> [] + NoopGate -> return () -- equality with "hardcoded" hash components PublicInputGate - -> [ hh i - ww i | i <- range 4 ] + -> commitList [ hash i - wire i | i <- range 4 ] PoseidonGate hash_width -> case hash_width of - 12 -> todo -- poseidonGateConstraints + 12 -> poseidonGateConstraints k -> error ( "gateConstraints/PoseidonGate: unsupported width " ++ show k) PoseidonMdsGate hash_width -> case hash_width of @@ -143,12 +112,4 @@ gateConstraints gate = todo = error $ "gateConstraints: gate `" ++ takeWhile isAlpha (show gate) ++ "` not yet implemented" - range k = [0..k-1] - - ww i = VarE (WireV i) -- witness variable - cc i = VarE (ConstV i) -- constant variable - hh i = VarE (PIV i) -- public input hash component - - wwExt i = ww i + ImagE (ww (i+1)) -- use two consecutive variables as an extension field element - --------------------------------------------------------------------------------- \ No newline at end of file +-------------------------------------------------------------------------------- diff --git a/src/Gate/Parser.hs b/src/Gate/Parser.hs index 0064072..9959d0a 100644 --- a/src/Gate/Parser.hs +++ b/src/Gate/Parser.hs @@ -1,5 +1,6 @@ --- | Gates are encoded as strings produced by ad-hoc of modification of Rust textual serialization... +-- | Gates are encoded as strings produced by some ad-hoc modifications +-- of the default Rust textual serialization... -- -- ... so we have to parse /that/ -- diff --git a/src/Gate/Poseidon.hs b/src/Gate/Poseidon.hs new file mode 100644 index 0000000..20c9fd4 --- /dev/null +++ b/src/Gate/Poseidon.hs @@ -0,0 +1,134 @@ + +-- | Plonky2's Poseidon gate +-- + +{-# LANGUAGE StrictData, RecordWildCards #-} +module Gate.Poseidon where + +-------------------------------------------------------------------------------- + +import Data.Array hiding (range) +import Data.Char +import Data.Foldable + +import Control.Monad +import Control.Monad.State.Strict + +import Algebra.Goldilocks +import Algebra.Expr + +import Gate.Vars +import Gate.Computation +import Hash.Constants + +-------------------------------------------------------------------------------- + +sbox :: Expr Var_ -> Compute (Expr Var_) +sbox x = do + x2 <- let_ "x2_" (x *x ) + x3 <- let_ "x3_" (x *x2) + x4 <- let_ "x4_" (x2*x2) + return (x3*x4) + +flipFoldM :: (Foldable t, Monad m) => b -> t a -> (b -> a -> m b) -> m b +flipFoldM s0 list action = foldM action s0 list + +flipFoldM_ :: (Foldable t, Monad m) => b -> t a -> (b -> a -> m b) -> m () +flipFoldM_ s0 list action = void (foldM action s0 list) + +-- | Poseidon state +type PS = [Expr Var_] + +-------------------------------------------------------------------------------- + +poseidonGateConstraints :: Compute () +poseidonGateConstraints = do + + -- merkle swap + let input_lhs i = input i + let input_rhs i = input (i+4) + commit $ (1 - swap_flag) * swap_flag + commitList [ swap_flag * (input_rhs i - input_lhs i) - delta i | i <- range 4 ] + + -- swapped inputs + let state0 :: PS + state0 + = [ input_lhs i + delta i | i <- [0.. 3] ] + ++ [ input_rhs (i-4) - delta (i-4) | i <- [4.. 7] ] + ++ [ input i | i <- [8..11] ] + + -- initial full rounds + phase1 <- flipFoldM state0 [0..3] $ \state1 r -> do + let state2 = plus_rc r state1 + state3 <- if r == 0 + then return state2 + else do + let sbox_in = initial_sbox_in r + commitList [ (state2!!i) - sbox_in i | i <- range 12 ] + return [ sbox_in i | i <- range 12 ] + state4 <- mapM sbox state3 + return $ mds state4 + + -- partial rounds + let state' = zipWith (+) phase1 (map LitE $ elems fast_PARTIAL_FIRST_ROUND_CONSTANT) + let state'' = mdsInitPartial state' + phase2 <- flipFoldM state'' [0..21] $ \state1 r -> do + let sbox_in = partial_sbox_in r + commit $ (state1!!0) - sbox_in + y <- sbox sbox_in + let z = if r < 21 then y + LitE (fast_PARTIAL_ROUND_CONSTANTS!r) else y + let state2 = z : tail state1 + return $ mdsFastPartial r state2 + + -- final full rounds + phase3 <- flipFoldM phase2 [0..3] $ \state1 r -> do + let state2 = plus_rc (r+26) state1 + let sbox_in = final_sbox_in r + commitList [ (state2!!i) - sbox_in i | i <- range 12 ] + state3 <- mapM sbox [ sbox_in i | i <- range 12 ] + return $ mds state3 + + -- constraint the output to be the result + commitList [ phase3!!i - output i | i <- range 12 ] + + where + + -- multiply by the MDS matrix + mds :: PS -> PS + mds state = + [ sum [ LitE (mdsMatrixCoeff i j) * x | (j,x) <- zip [0..] state ] + | i <- range 12 + ] + + dotProd :: PS -> [F] -> Expr Var_ + dotProd es cs = sum $ zipWith (\e c -> e * LitE c) es cs + + mdsInitPartial :: PS -> PS + mdsInitPartial (first:rest) + = first + : [ sum [ LitE (partialMdsMatrixCoeff i j) * x | (j,x) <- zip [0..] rest ] + | i <- range 11 + ] + + mdsFastPartial :: Int -> PS -> PS + mdsFastPartial r state@(s0:rest) = res where + m0 = mdsMatrixCoeff 0 0 + cs = m0 : elems (fast_PARTIAL_ROUND_W_HATS!r) + d = dotProd state cs + res = d : [ x + s0 * LitE t | (x,t) <- zip rest (elems $ fast_PARTIAL_ROUND_VS!r) ] + + -- add round constants + plus_rc :: Int -> PS -> PS + plus_rc r state = zipWith (+) state (map LitE $ elems (all_ROUND_CONSTANTS!r)) + + -- witness variables + input i = wire i + output i = wire (i+12) + swap_flag = wire 24 + delta i = wire (25+i) + initial_sbox_in r i = wire (29 + 12*(r-1) + i) -- 0 < r < 4 , 0 <= i < 12 + partial_sbox_in r = wire (29 + 36 + r) -- 0 <= r < 22 + final_sbox_in r i = wire (29 + 36 + 22 + 12*r + i) -- 0 <= r < 4 , 0 <= i < 12 + +-------------------------------------------------------------------------------- + diff --git a/src/Gate/Vars.hs b/src/Gate/Vars.hs new file mode 100644 index 0000000..593f05e --- /dev/null +++ b/src/Gate/Vars.hs @@ -0,0 +1,58 @@ + +-- | Variables appearing in gate constraints + +module Gate.Vars where + +-------------------------------------------------------------------------------- + +import Text.Show + +import Algebra.Expr +import Misc.Pretty + +-------------------------------------------------------------------------------- +-- * Constraint variables + +-- | These index into a row + public input +data PlonkyVar + = SelV Int -- ^ selector variable + | ConstV Int -- ^ constant variable + | WireV Int -- ^ wire variable + | PIV Int -- ^ public input hash variable + deriving (Eq,Ord,Show) + +instance Pretty PlonkyVar where + prettyPrec _ v = case v of + SelV k -> showString ("s" ++ show k) + ConstV k -> showString ("c" ++ show k) + WireV k -> showString ("w" ++ show k) + PIV k -> showString ("h" ++ show k) + +-------------------------------------------------------------------------------- +-- * Variables + +data Var v + = LocalVar Int String -- ^ a temporary variable + | ProofVar v -- ^ a proof variable (eg. witness, constant, selector) + deriving (Eq,Show) + +instance Pretty v => Pretty (Var v) where + prettyPrec d var = case var of + LocalVar k name -> showString ("_" ++ name ++ show k) + ProofVar v -> prettyPrec d v + +-------------------------------------------------------------------------------- +-- * Convenience + +range :: Int -> [Int] +range k = [0..k-1] + +wire, cnst, hash :: Int -> Expr (Var PlonkyVar) +wire i = VarE $ ProofVar $ WireV i -- witness variable +cnst i = VarE $ ProofVar $ ConstV i -- constant variable +hash i = VarE $ ProofVar $ PIV i -- public input hash component + +wireExt :: Int -> Expr (Var PlonkyVar) +wireExt i = wire i + ImgE (wire (i+1)) -- use two consecutive variables as an extension field element + +-------------------------------------------------------------------------------- diff --git a/src/Hash/Constants.hs b/src/Hash/Constants.hs index dcebd64..4921303 100644 --- a/src/Hash/Constants.hs +++ b/src/Hash/Constants.hs @@ -40,10 +40,8 @@ fast_PARTIAL_ROUND_CONSTANTS = listArray (0,21) , 0x1aca78f31c97c876 , 0x0 ] -{- - -fast_PARTIAL_ROUND_VS :: [Array Int F] -fast_PARTIAL_ROUND_VS = map (listArray (0,10)) +fast_PARTIAL_ROUND_VS :: Array Int (Array Int F) +fast_PARTIAL_ROUND_VS = listArray (0,21) $ map (listArray (0,10)) [ [0x94877900674181c3, 0xc6c67cc37a2a2bbd, 0xd667c2055387940f, 0x0ba63a63e94b5ff0, 0x99460cc41b8f079f, 0x7ff02375ed524bb3, 0xea0870b47a8caf0e, 0xabcad82633b7bc9d, 0x3b8d135261052241, 0xfb4515f5e5b0d539, 0x3ee8011c2b37f77c ] , [0x0adef3740e71c726, 0xa37bf67c6f986559, 0xc6b16f7ed4fa1b00, 0x6a065da88d8bfc3c, 0x4cabc0916844b46f, 0x407faac0f02e78d1, 0x07a786d9cf0852cf, 0x42433fb6949a629a, 0x891682a147ce43b0, 0x26cfd58e7b003b55, 0x2bbf0ed7b657acb3 ] , [0x481ac7746b159c67, 0xe367de32f108e278, 0x73f260087ad28bec, 0x5cfc82216bc1bdca, 0xcaccc870a2663a0e, 0xdb69cd7b4298c45d, 0x7bc9e0c57243e62d, 0x3cc51c5d368693ae, 0x366b4e8cc068895b, 0x2bd18715cdabbca4, 0xa752061c4f33b8cf ] @@ -68,8 +66,8 @@ fast_PARTIAL_ROUND_VS = map (listArray (0,10)) , [0x0000000000000014, 0x0000000000000022, 0x0000000000000012, 0x0000000000000027, 0x000000000000000d, 0x000000000000000d, 0x000000000000001c, 0x0000000000000002, 0x0000000000000010, 0x0000000000000029, 0x000000000000000f ] ] -fast_PARTIAL_ROUND_W_HATS :: [Array Int F] -fast_PARTIAL_ROUND_W_HATS = map (listArray (0,10)) +fast_PARTIAL_ROUND_W_HATS :: Array Int (Array Int F) +fast_PARTIAL_ROUND_W_HATS = listArray (0,21) $ map (listArray (0,10)) [ [0x3d999c961b7c63b0, 0x814e82efcd172529, 0x2421e5d236704588, 0x887af7d4dd482328, 0xa5e9c291f6119b27, 0xbdc52b2676a4b4aa, 0x64832009d29bcf57, 0x09c4155174a552cc, 0x463f9ee03d290810, 0xc810936e64982542, 0x043b1c289f7bc3ac ] , [0x673655aae8be5a8b, 0xd510fe714f39fa10, 0x2c68a099b51c9e73, 0xa667bfa9aa96999d, 0x4d67e72f063e2108, 0xf84dde3e6acda179, 0x40f9cc8c08f80981, 0x5ead032050097142, 0x6591b02092d671bb, 0x00e18c71963dd1b7, 0x8a21bcd24a14218a ] , [0x202800f4addbdc87, 0xe4b5bdb1cc3504ff, 0xbe32b32a825596e7, 0x8e0f68c5dc223b9a, 0x58022d9e1c256ce3, 0x584d29227aa073ac, 0x8b9352ad04bef9e7, 0xaead42a3f445ecbf, 0x3c667a1d833a3cca, 0xda6f61838efa1ffe, 0xe8f749470bd7c446 ] @@ -94,8 +92,6 @@ fast_PARTIAL_ROUND_W_HATS = map (listArray (0,10)) , [0x3abeb80def61cc85, 0x9d19c9dd4eac4133, 0x075a652d9641a985, 0x9daf69ae1b67e667, 0x364f71da77920a18, 0x50bd769f745c95b1, 0xf223d1180dbbf3fc, 0x2f885e584e04aa99, 0xb69a0fa70aea684a, 0x09584acaa6e062a0, 0x0bc051640145b19b ] ] --} - -- ^ NB: This is in ROW-major order to support cache-friendly pre-multiplication. fast_PARTIAL_ROUND_INITIAL_MATRIX :: Array (Int,Int) F fast_PARTIAL_ROUND_INITIAL_MATRIX = listArray ((0,0),(10,10)) $ concat @@ -112,6 +108,9 @@ fast_PARTIAL_ROUND_INITIAL_MATRIX = listArray ((0,0),(10,10)) $ concat , [0xd841e8ef9dde8ba0, 0x156048ee7a738154, 0x85418a9fef8a9890, 0x64dd936da878404d, 0x726af914971c1374, 0x7f8e41e0b0a6cdff, 0xf97abba0dffb6c50, 0xf4a437f2888ae909, 0xdcedab70f40718ba, 0xe796d293a47a64cb, 0x80772dc2645b280b ] ] +partialMdsMatrixCoeff :: Int -> Int -> F +partialMdsMatrixCoeff i j = fast_PARTIAL_ROUND_INITIAL_MATRIX ! (j,i) + partition12 :: [a] -> [[a]] partition12 = go where go [] = [] diff --git a/src/Misc/Pretty.hs b/src/Misc/Pretty.hs new file mode 100644 index 0000000..dfcd301 --- /dev/null +++ b/src/Misc/Pretty.hs @@ -0,0 +1,15 @@ + +-- | Precedence-aware pretty-printing + +module Misc.Pretty where + +-------------------------------------------------------------------------------- + +-- | See "Text.Show" +class Pretty a where + prettyPrec :: Int -> a -> (String -> String) + +pretty :: Pretty a => a -> String +pretty x = prettyPrec 0 x "" + +-------------------------------------------------------------------------------- diff --git a/src/Test/Witness.hs b/src/Test/Witness.hs new file mode 100644 index 0000000..86e4ab0 --- /dev/null +++ b/src/Test/Witness.hs @@ -0,0 +1,82 @@ + +-- | We can test the gate constraints on actual witness rows during development + +{-# LANGUAGE StrictData, DeriveGeneric, DeriveAnyClass, RecordWildCards #-} +module Test.Witness where + +-------------------------------------------------------------------------------- + +import Data.Array + +import Data.Aeson +import GHC.Generics +import qualified Data.ByteString.Lazy as L + +import Algebra.Goldilocks +import Algebra.GoldilocksExt + +import Gate.Base +import Gate.Computation +import Gate.Constraints + +-------------------------------------------------------------------------------- + +-- | As exported into JSON by our Plonky2 fork +data Witness = MkWitness + { gates :: [String] + , selector_vector :: [Int] + , selector_columns :: [[F]] + , constants_columns :: [[F]] + , matrix :: [[F]] + } + deriving (Show,Generic,ToJSON,FromJSON) + +matrixRow :: [[F]] -> Int -> [F] +matrixRow matrix i = map (!!i) matrix + +witnessRow :: Witness -> Int -> EvaluationVars F +witnessRow (MkWitness{..}) i = MkEvaluationVars + { local_selectors = toArray (matrixRow selector_columns i) + , local_constants = toArray (matrixRow constants_columns i) + , local_wires = toArray (matrixRow matrix i) + , public_inputs_hash = [] + } + where + toArray xs = listArray (0, length xs - 1) xs + +loadWitness :: FilePath -> IO Witness +loadWitness fpath = do + txt <- L.readFile fpath + case decode txt of + Just w -> return w + Nothing -> fail "loadWitness: cannot parse JSON witness" + +-------------------------------------------------------------------------------- + +test_fibonacci = do + + witness <- loadWitness "../json/fibonacci_witness.json" + + let arith_row = witnessRow witness 0 + let posei_row = witnessRow witness 5 + let pubio_row = witnessRow witness 6 + let const_row = witnessRow witness 7 + + let arith_prg = gateProgram (ArithmeticGate 20) + let posei_prg = gateProgram (PoseidonGate 12) + let pubio_prg = gateProgram (PublicInputGate ) + let const_prg = gateProgram (ConstantGate 2) + + let arith_evals = runStraightLine (fmap fromBase arith_row) arith_prg + let const_evals = runStraightLine (fmap fromBase const_row) const_prg + let posei_evals = runStraightLine (fmap fromBase posei_row) posei_prg + + putStrLn $ "number of constants in ArithmeticGate = " ++ show (length arith_evals) + putStrLn $ "number of constants in ConstantGate = " ++ show (length const_evals) + putStrLn $ "number of constants in PosiedoGate = " ++ show (length posei_evals) + + print arith_evals + print const_evals + print posei_evals + +--------------------------------------------------------------------------------