completely rewrite the constraints; PoseidonGate seems to work now

This commit is contained in:
Balazs Komuves 2024-12-15 14:03:44 +01:00
parent 338163f56d
commit c949f3d3f2
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
12 changed files with 563 additions and 157 deletions

View File

@ -2,7 +2,7 @@
A standalone Plonky2 verifier A standalone Plonky2 verifier
----------------------------- -----------------------------
This is a (WIP) implementation of a Plonky2 verifier in Haskell. This is a (WIP) implementation of a Plonky2 verifier written in Haskell.
[Plonky2](https://github.com/0xPolygonZero/plonky2/) is a zero-knowledge proof [Plonky2](https://github.com/0xPolygonZero/plonky2/) is a zero-knowledge proof
system developed by Polygon Zero, optimized for recursive proofs. system developed by Polygon Zero, optimized for recursive proofs.
@ -15,7 +15,7 @@ Another goal is to be a basis for further tooling (for example:
estimating verifier costs, helping the design of recursive circuits, generating estimating verifier costs, helping the design of recursive circuits, generating
Plonky2 verifier circuits for other proof systems, etc) Plonky2 verifier circuits for other proof systems, etc)
Note: It's deliberately no a goal for this verifier to be efficient; instead we Note: It's deliberately not a goal for this verifier to be efficient; instead we
try to focus on simplicity. try to focus on simplicity.
@ -42,7 +42,7 @@ Supported gates:
- [x] MulExtensionGate - [x] MulExtensionGate
- [x] NoopGate - [x] NoopGate
- [x] PublicInputGate - [x] PublicInputGate
- [ ] PoseidonGate - [x] PoseidonGate
- [ ] PoseidonMdsGate - [ ] PoseidonMdsGate
- [ ] RandomAccessGate - [ ] RandomAccessGate
- [ ] ReducingGate - [ ] ReducingGate
@ -52,6 +52,6 @@ Optional features:
- [ ] Field extensions with degree higher than 2 - [ ] Field extensions with degree higher than 2
- [ ] Being parametric over the field choice - [ ] Being parametric over the field choice
- [ ] Supporting different hash function choices - [ ] Supporting different hash functions

View File

@ -7,6 +7,8 @@ module Algebra.Expr where
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
import Prelude hiding ( (^) )
import Data.Array import Data.Array
import Data.Char import Data.Char
@ -16,102 +18,76 @@ import Algebra.Goldilocks
import Algebra.GoldilocksExt import Algebra.GoldilocksExt
import Gate.Base import Gate.Base
import Misc.Pretty
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- * Polynomial expressions -- * Polynomial expressions
data Expr v data Expr v
= VarE v -- ^ a variable = VarE v -- ^ a variable
| LitE F -- ^ constant literal | LitE F -- ^ constant literal
| ScaleE F (Expr v) -- ^ linear scaling by a constant | AddE (Expr v) (Expr v) -- ^ addition
| ImagE (Expr v) -- ^ multiplies by the field extension generator X | SubE (Expr v) (Expr v) -- ^ subtraction
| SumE [Expr v] -- ^ sum of expressions | MulE (Expr v) (Expr v) -- ^ multiplication
| ProdE [Expr v] -- ^ product of expressions | ImgE (Expr v) -- ^ multiplies by the field extension generator X
| PowE (Expr v) Int -- ^ exponentiation deriving (Eq,Show)
deriving (Eq) -- ,Show)
instance Pretty var => Show (Expr var) where show = pretty -- instance Pretty var => Show (Expr var) where show = pretty
-- | Degree of the expression -- | Degree of the expression
exprDegree :: Expr var -> Int exprDegree :: Expr var -> Int
exprDegree = go where exprDegree = go where
go expr = case expr of go expr = case expr of
VarE _ -> 1 VarE _ -> 1
LitE _ -> 0 LitE _ -> 0
ScaleE _ e -> go e AddE e1 e2 -> max (go e1) (go e2)
ImagE e -> go e SubE e1 e2 -> max (go e1) (go e2)
SumE es -> if null es then 0 else maximum (map go es) MulE e1 e2 -> go e1 + go e2
ProdE es -> sum (map go es) ImgE e -> go e
PowE e n -> n * go e
instance Num (Expr var) where instance Num (Expr var) where
fromInteger = LitE . fromInteger fromInteger = LitE . fromInteger
negate = negE negate = negE
(+) = addE (+) = AddE
(-) = subE (-) = SubE
(*) = mulE (*) = MulE
abs = error "Expr/abs" abs = error "Expr/abs"
signum = error "Expr/signum" signum = error "Expr/signum"
negE :: Expr var -> Expr var negE :: Expr var -> Expr var
negE (ScaleE s e) = ScaleE (negate s) e negE e = SubE (LitE 0) e
negE e = ScaleE (-1) e
{-
(^) :: Expr var -> Int -> Expr var
(^) = PowE
addE :: Expr var -> Expr var -> Expr var addE :: Expr var -> Expr var -> Expr var
addE (SumE es) (SumE fs) = SumE (es++fs ) addE = AddE
addE e (SumE fs) = SumE (e : fs )
addE (SumE es) f = SumE (es++[f])
addE e f = SumE [e,f]
subE :: Expr var -> Expr var -> Expr var
subE e f = addE e (negate f)
sclE :: F -> Expr var -> Expr var
sclE s (ScaleE t e) = sclE (s*t) e
sclE s e = ScaleE s e
mulE :: Expr var -> Expr var -> Expr var mulE :: Expr var -> Expr var -> Expr var
mulE (ScaleE s e) (ScaleE t f) = sclE (s*t) (mulE e f) mulE = MulE
mulE (ScaleE s e) f = sclE s (mulE e f) -}
mulE (LitE s) f = sclE s f
mulE e (LitE t) = sclE t e
mulE e (ScaleE t f) = sclE t (mulE e f)
mulE (ProdE es) (ProdE fs) = ProdE (es++fs )
mulE e (ProdE fs) = ProdE (e : fs )
mulE (ProdE es) f = ProdE (es++[f])
mulE e f = ProdE [e,f]
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- * pretty printing -- * pretty printing
-- | TODO: maybe move this somewhere else
class Pretty a where
prettyPrec :: Int -> a -> (String -> String)
pretty :: Pretty a => a -> String
pretty x = prettyPrec 0 x ""
instance Pretty F where prettyPrec _ x = shows x
instance Pretty FExt where prettyPrec _ x = shows x
instance Pretty var => Pretty (Expr var) where instance Pretty var => Pretty (Expr var) where
prettyPrec d expr = prettyPrec d expr =
case expr of case expr of
VarE v -> prettyPrec 0 v VarE v -> prettyPrec 0 v
LitE x -> prettyPrec 0 x LitE x -> prettyPrec 0 x
ScaleE s e -> prettyPrec 0 s . showString " * " . showParen (d > mul_prec) (prettyPrec mul_prec e) AddE e1 e2 -> showParen (d > add_prec) $ prettyPrec add_prec e1 . showString " + " . prettyPrec (add_prec+1) e2
ImagE e -> showString "X*" . showParen (d > mul_prec) (prettyPrec mul_prec e) SubE e1 e2 -> showParen (d > add_prec) $ prettyPrec add_prec e1 . showString " - " . prettyPrec (add_prec+1) e2
SumE es -> showParen (d > add_prec) $ intercalates " + " $ map (prettyPrec add_prec) es MulE e1 e2 -> showParen (d > mul_prec) $ prettyPrec add_prec e1 . showString " * " . prettyPrec (mul_prec+1) e2
ProdE es -> showParen (d > mul_prec) $ intercalates " * " $ map (prettyPrec mul_prec) es ImgE e -> showParen (d > mul_prec) $ showString "X*" . (prettyPrec mul_prec e)
PowE e k -> showParen (d > pow_prec) $ (prettyPrec pow_prec e) . showString ("^" ++ show k)
where where
add_prec = 5 add_prec = 5
mul_prec = 6 mul_prec = 6
pow_prec = 7 -- pow_prec = 7
intercalates sep = go where -- intercalates sep = go where
go [] = id -- go [] = id
go [x] = x -- go [x] = x
go (x:xs) = x . showString sep . go xs -- go (x:xs) = x . showString sep . go xs
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- * Evaluation -- * Evaluation
@ -127,12 +103,11 @@ instance EvalField FExt where fromGoldilocks = fromBase
evalExprWith :: (var -> FExt) -> Expr var -> FExt evalExprWith :: (var -> FExt) -> Expr var -> FExt
evalExprWith evalVar expr = go expr where evalExprWith evalVar expr = go expr where
go e = case e of go e = case e of
VarE v -> evalVar v VarE v -> evalVar v
LitE x -> fromBase x LitE x -> fromBase x
ScaleE s e -> fromBase s * go e AddE e1 e2 -> go e1 + go e2
ImagE e -> (MkExt 0 1) * go e SubE e1 e2 -> go e1 - go e2
SumE es -> sum (map go es) MulE e1 e2 -> go e1 * go e2
ProdE es -> product (map go es) ImgE e -> (MkExt 0 1) * go e
PowE e n -> powExt (go e) (fromIntegral n)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------

View File

@ -14,6 +14,7 @@ import Data.Word
import Data.Ratio import Data.Ratio
import Data.Array import Data.Array
import Text.Show
import Text.Printf import Text.Printf
import System.Random import System.Random
@ -21,6 +22,8 @@ import System.Random
import GHC.Generics import GHC.Generics
import Data.Aeson ( ToJSON(..), FromJSON(..) ) import Data.Aeson ( ToJSON(..), FromJSON(..) )
import Misc.Pretty
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
type F = Goldilocks type F = Goldilocks
@ -78,6 +81,8 @@ instance Show Goldilocks where
show (Goldilocks x) = show x -- decimal show (Goldilocks x) = show x -- decimal
-- show (Goldilocks x) = printf "0x%016x" x -- hex -- show (Goldilocks x) = printf "0x%016x" x -- hex
instance Pretty Goldilocks where prettyPrec _ x = shows x
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
instance ToJSON Goldilocks where instance ToJSON Goldilocks where

View File

@ -12,10 +12,12 @@ module Algebra.GoldilocksExt where
import Data.Bits import Data.Bits
import Data.Ratio import Data.Ratio
import Text.Show
import Data.Aeson ( ToJSON(..), FromJSON(..) ) import Data.Aeson ( ToJSON(..), FromJSON(..) )
import Algebra.Goldilocks import Algebra.Goldilocks
import Misc.Pretty
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -31,6 +33,12 @@ fromBase x = MkExt x 0
instance Show GoldilocksExt where instance Show GoldilocksExt where
show (MkExt real imag) = "(" ++ show real ++ " + X*" ++ show imag ++ ")" show (MkExt real imag) = "(" ++ show real ++ " + X*" ++ show imag ++ ")"
instance Pretty GoldilocksExt where
prettyPrec d (MkExt real imag)
| imag == 0 = prettyPrec 0 real
| otherwise = showParen (d > 5)
$ prettyPrec 0 real . showString " + X*" . prettyPrec 0 imag
instance ToJSON GoldilocksExt where instance ToJSON GoldilocksExt where
toJSON (MkExt a b) = toJSON (a,b) toJSON (MkExt a b) = toJSON (a,b)
@ -81,3 +89,11 @@ powExt x e
_ -> go (acc*s) (sqrExt s) (shiftR expo 1) _ -> go (acc*s) (sqrExt s) (shiftR expo 1)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
rndExt :: IO FExt
rndExt = do
x <- rndF
y <- rndF
return (MkExt x y)
--------------------------------------------------------------------------------

160
src/Gate/Computation.hs Normal file
View File

@ -0,0 +1,160 @@
-- | We have to describe the contraints as computations with local definitions
-- Without local definitions the equations would just blow up
-- in the case of the Poseidon gate for example
--
{-# LANGUAGE StrictData, DeriveFunctor, GADTs, RecordWildCards #-}
module Gate.Computation where
--------------------------------------------------------------------------------
import Prelude hiding ( (^) )
import Control.Applicative
import Control.Monad
import Data.Array
import Data.List
import Text.Show
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Algebra.Expr
import Gate.Vars
import Misc.Pretty
--------------------------------------------------------------------------------
-- * Operational monad
data Instr var a where
Let :: String -> Expr var -> Instr var var
Commit :: Expr var -> Instr var ()
data Program instr a where
Bind :: Program instr a -> (a -> Program instr b) -> Program instr b
Return :: a -> Program instr a
Instr :: instr a -> Program instr a
instance Functor (Program instr) where
fmap = liftM
instance Applicative (Program instr) where
(<*>) = ap
pure = Return
instance Monad (Program instr) where
(>>=) = Bind
type Var_ = Var PlonkyVar
type Expr_ = Expr Var_
type Def_ = LocalDef Var_
--type Instr_ a = Instr Var_ a
--------------------------------------------------------------------------------
-- | Our computation monad
type Compute a = Program (Instr Var_) a
let_ :: String -> Expr Var_ -> Compute (Expr Var_)
let_ name rhs = VarE <$> Instr (Let name rhs)
commit :: Expr Var_ -> Compute ()
commit what = Instr (Commit what)
commitList :: [Expr Var_] -> Compute ()
commitList = mapM_ commit
--------------------------------------------------------------------------------
-- | Straightline programs
data LocalDef v
= MkLocalDef Int String (Expr v)
deriving (Eq,Show)
instance Pretty v => Pretty (LocalDef v) where
prettyPrec _ (MkLocalDef k name rhs) = showString ("_" ++ name ++ show k) . showString " := " . prettyPrec 0 rhs
-- | A straightline program encoding the computation of constraints
data StraightLine = MkStraightLine
{ localdefs :: [LocalDef Var_] -- ^ local definitions, in reverse order
, commits :: [Expr_] -- ^ committed constraints, in reverse order
, counter :: Int -- ^ fresh variable counter
}
deriving Show
emptyStraightLine :: StraightLine
emptyStraightLine = MkStraightLine [] [] 0
printStraightLine :: StraightLine -> IO ()
printStraightLine (MkStraightLine{..}) = do
forM_ (reverse localdefs) $ \def -> putStrLn (pretty def)
forM_ (reverse commits ) $ \expr -> putStrLn $ "constraint 0 == " ++ (pretty expr)
compileToStraightLine :: Compute () -> StraightLine
compileToStraightLine = fst . go emptyStraightLine where
go :: StraightLine -> Compute a -> (StraightLine,a)
go state instr = case instr of
Return x -> (state,x)
Bind this rest -> let (state',x) = go state this in go state' (rest x)
Instr this -> case state of
MkStraightLine{..} -> case this of
Commit what -> let state' = MkStraightLine localdefs (what:commits) counter
in (state', ())
Let name rhs -> let def = MkLocalDef counter name rhs
state' = MkStraightLine (def:localdefs) commits (counter+1)
in (state', LocalVar counter name)
--------------------------------------------------------------------------------
type Scope a = IntMap a
emptyScope :: Scope a
emptyScope = IntMap.empty
-- | Run a \"straightline program\", resulting in list of contraints evaluations
runStraightLine :: EvaluationVars FExt -> StraightLine -> [FExt]
runStraightLine = runStraightLine' emptyScope
runStraightLine' :: Scope FExt -> EvaluationVars FExt -> StraightLine -> [FExt]
runStraightLine' iniScope vars (MkStraightLine{..}) = result where
finalScope = foldl' worker iniScope (reverse localdefs)
result = evalConstraints finalScope vars (reverse commits)
worker !scope (MkLocalDef i _ rhs) = IntMap.insert i (evalConstraint scope vars rhs) scope
--------------------------------------------------------------------------------
-- * Evaluation
type Constraint = Expr_
-- | List of all data (one "row") we need to evaluate a gate constraint
--
-- Typically this will be the evaluations of the column polynomials at @zeta@
data EvaluationVars a = MkEvaluationVars
{ local_selectors :: Array Int a -- ^ the selectors
, local_constants :: Array Int a -- ^ the circuit constants
, local_wires :: Array Int a -- ^ the advice wires (witness)
, public_inputs_hash :: [F] -- ^ only used in @PublicInputGate@
}
deriving (Show,Functor)
evalConstraint :: Scope FExt -> EvaluationVars FExt -> Constraint -> FExt
evalConstraint scope (MkEvaluationVars{..}) expr = evalExprWith f expr where
f var = case var of
LocalVar i n -> case IntMap.lookup i scope of
Just y -> y
Nothing -> error $ "variable _" ++ n ++ show i ++ " not in scope"
ProofVar v -> case v of
SelV k -> local_selectors ! k
ConstV k -> local_constants ! k
WireV k -> local_wires ! k
PIV k -> fromBase (public_inputs_hash !! k)
evalConstraints :: Scope FExt -> EvaluationVars FExt -> [Constraint] -> [FExt]
evalConstraints scope vars = map (evalConstraint scope vars)
--------------------------------------------------------------------------------

View File

@ -5,123 +5,92 @@
-- of constraints. -- of constraints.
-- --
{-# LANGUAGE StrictData, RecordWildCards #-} {-# LANGUAGE StrictData, DeriveFunctor, GADTs, RecordWildCards #-}
module Gate.Constraints where module Gate.Constraints where
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
import Data.Array import Prelude hiding ( (^) )
import Data.Char
import Data.Array hiding (range)
import Data.Char
import Text.Show import Text.Show
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Algebra.Goldilocks import Algebra.Goldilocks
import Algebra.GoldilocksExt import Algebra.GoldilocksExt
import Algebra.Expr import Algebra.Expr
import Gate.Base import Gate.Base
import Gate.Vars
import Gate.Computation
import Gate.Poseidon
--------------------------------------------------------------------------------
-- * Constraint expressions
-- | These index into a row + public input
data Var
= SelV Int -- ^ selector variable
| ConstV Int -- ^ constant variable
| WireV Int -- ^ wire variable
| PIV Int -- ^ public input hash variable
deriving (Eq,Ord,Show)
instance Pretty Var where
prettyPrec _ v = case v of
SelV k -> showString ("s" ++ show k)
ConstV k -> showString ("c" ++ show k)
WireV k -> showString ("w" ++ show k)
PIV k -> showString ("h" ++ show k)
--------------------------------------------------------------------------------
-- | List of all data (one "row") we need to evaluate a gate constraint
--
-- Typically this will be the evaluations of the column polynomials at @zeta@
data EvaluationVars a = MkEvaluationVars
{ local_selectors :: Array Int a -- ^ the selectors
, local_constants :: Array Int a -- ^ the circuit constants
, local_wires :: Array Int a -- ^ the advice wires (witness)
, public_inputs_hash :: [F] -- ^ only used in @PublicInputGate@
}
deriving (Show)
--------------------------------------------------------------------------------
-- * Evaluation
evalExpr :: Expr Var -> EvaluationVars FExt -> FExt
evalExpr expr (MkEvaluationVars{..}) = evalExprWith f expr where
f v = case v of
SelV k -> local_selectors ! k
ConstV k -> local_constants ! k
WireV k -> local_wires ! k
PIV k -> fromBase (public_inputs_hash !! k)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- * Gate constraints -- * Gate constraints
-- | Returns the (symbolic) constraints for the given gate -- | Returns the (symbolic) constraints for the given gate
gateConstraints :: Gate -> [Expr Var] --
gateConstraints gate = gateProgram :: Gate -> StraightLine
gateProgram = compileToStraightLine . gateComputation
gateComputation :: Gate -> Compute ()
gateComputation gate =
case gate of case gate of
-- `w[i] - c0*x[i]*y[i] - c1*z[i] = 0` -- `w[i] - c0*x[i]*y[i] - c1*z[i] = 0`
ArithmeticGate num_ops ArithmeticGate num_ops
-> [ ww (j+3) - cc 0 * ww j * ww (j+1) - cc 1 * ww (j+2) | i<-range num_ops, let j = 4*i ] -> commitList [ wire (j+3) - cnst 0 * wire j * wire (j+1) - cnst 1 * wire (j+2) | i<-range num_ops, let j = 4*i ]
-- same but consecutive witness variables make up an extension field element -- same but consecutive witness variables make up an extension field element
ArithmeticExtensionGate num_ops ArithmeticExtensionGate num_ops
-> [ wwExt (j+6) - cc 0 * wwExt j * wwExt (j+2) - cc 1 * wwExt (j+4) | i<-range num_ops, let j = 8*i ] -> commitList [ wireExt (j+6) - cnst 0 * wireExt j * wireExt (j+2) - cnst 1 * wireExt (j+4) | i<-range num_ops, let j = 8*i ]
-- `sum b^i * limbs[i] - out = 0`, and `0 <= limb[i] < B` is enforced -- `sum b^i * limbs[i] - out = 0`, and `0 <= limb[i] < B` is enforced
BaseSumGate num_limbs base BaseSumGate num_limbs base
-> let limb i = ww (i+1) -> let limb i = wire (i+1)
horner = go 0 where go k = if k < num_limbs-1 then limb k + fromIntegral base * go (k+1) else limb k horner = go 0 where go k = if k < num_limbs-1 then limb k + fromIntegral base * go (k+1) else limb k
sum_eq = horner - ww 0 sum_eq = horner - wire 0
range_eq i = ProdE [ limb i - fromIntegral k | k<-[0..base-1] ] range_eq i = product [ limb i - fromIntegral k | k<-[0..base-1] ]
in sum_eq : [ range_eq i | i<-range num_limbs ] in commitList $ sum_eq : [ range_eq i | i<-range num_limbs ]
CosetInterpolationGate subgroup_bits coset_degree barycentric_weights CosetInterpolationGate subgroup_bits coset_degree barycentric_weights
-> todo -> todo
-- `c[i] - x[i] = 0` -- `c[i] - x[i] = 0`
ConstantGate num_consts ConstantGate num_consts
-> [ cc i - ww i | i <- range num_consts ] -> commitList [ cnst i - wire i | i <- range num_consts ]
-- computes `out = base ^ (sum 2^i e_i)` -- computes `out = base ^ (sum 2^i e_i)`
-- order of witness variables: [ base, e[0],...,e[n-1], output, t[0]...t[n-1] ] -- order of witness variables: [ base, e[0],...,e[n-1], output, t[0]...t[n-1] ]
ExponentiationGate num_power_bits ExponentiationGate num_power_bits
-> let base = ww 0 -> let base = wire 0
exp_bit i = ww (i+1) exp_bit i = wire (i+1)
out = ww (num_power_bits+1) out = wire (num_power_bits+1)
tmp_val 0 = 1 tmp_val 0 = 1
tmp_val i = ww (num_power_bits+1+i) tmp_val i = wire (num_power_bits+1+i)
cur_bit i = exp_bit (num_power_bits - 1 - i) cur_bit i = exp_bit (num_power_bits - 1 - i)
eq i = tmp_val (i-1) * (cur_bit i * base + 1 - cur_bit i) - tmp_val i eq i = tmp_val (i-1) * (cur_bit i * base + 1 - cur_bit i) - tmp_val i
in [ eq i | i <- range num_power_bits ] ++ [ out - tmp_val (num_power_bits-1) ] in commitList $ [ eq i | i <- range num_power_bits ] ++ [ out - tmp_val (num_power_bits-1) ]
-- lookups are handled specially, no constraints here -- lookups are handled specially, no constraints here
LookupGate num_slots lut_hash -> [] LookupGate num_slots lut_hash -> return ()
LookupTableGate num_slots lut_hash last_lut_row -> [] LookupTableGate num_slots lut_hash last_lut_row -> return ()
-- `z[i] - c0*x[i]*y[i] = 0`, and two witness cells make up an extension field element -- `z[i] - c0*x[i]*y[i] = 0`, and two witness cells make up an extension field element
MulExtensionGate num_ops MulExtensionGate num_ops
-> [ wwExt (j+4) - cc 0 * wwExt j * wwExt (j+2) | i<-range num_ops, let j = 6*i ] -> commitList [ wireExt (j+4) - cnst 0 * wireExt j * wireExt (j+2) | i<-range num_ops, let j = 6*i ]
NoopGate -> [] NoopGate -> return ()
-- equality with "hardcoded" hash components -- equality with "hardcoded" hash components
PublicInputGate PublicInputGate
-> [ hh i - ww i | i <- range 4 ] -> commitList [ hash i - wire i | i <- range 4 ]
PoseidonGate hash_width -> case hash_width of PoseidonGate hash_width -> case hash_width of
12 -> todo -- poseidonGateConstraints 12 -> poseidonGateConstraints
k -> error ( "gateConstraints/PoseidonGate: unsupported width " ++ show k) k -> error ( "gateConstraints/PoseidonGate: unsupported width " ++ show k)
PoseidonMdsGate hash_width -> case hash_width of PoseidonMdsGate hash_width -> case hash_width of
@ -143,12 +112,4 @@ gateConstraints gate =
todo = error $ "gateConstraints: gate `" ++ takeWhile isAlpha (show gate) ++ "` not yet implemented" todo = error $ "gateConstraints: gate `" ++ takeWhile isAlpha (show gate) ++ "` not yet implemented"
range k = [0..k-1] --------------------------------------------------------------------------------
ww i = VarE (WireV i) -- witness variable
cc i = VarE (ConstV i) -- constant variable
hh i = VarE (PIV i) -- public input hash component
wwExt i = ww i + ImagE (ww (i+1)) -- use two consecutive variables as an extension field element
--------------------------------------------------------------------------------

View File

@ -1,5 +1,6 @@
-- | Gates are encoded as strings produced by ad-hoc of modification of Rust textual serialization... -- | Gates are encoded as strings produced by some ad-hoc modifications
-- of the default Rust textual serialization...
-- --
-- ... so we have to parse /that/ -- ... so we have to parse /that/
-- --

134
src/Gate/Poseidon.hs Normal file
View File

@ -0,0 +1,134 @@
-- | Plonky2's Poseidon gate
--
{-# LANGUAGE StrictData, RecordWildCards #-}
module Gate.Poseidon where
--------------------------------------------------------------------------------
import Data.Array hiding (range)
import Data.Char
import Data.Foldable
import Control.Monad
import Control.Monad.State.Strict
import Algebra.Goldilocks
import Algebra.Expr
import Gate.Vars
import Gate.Computation
import Hash.Constants
--------------------------------------------------------------------------------
sbox :: Expr Var_ -> Compute (Expr Var_)
sbox x = do
x2 <- let_ "x2_" (x *x )
x3 <- let_ "x3_" (x *x2)
x4 <- let_ "x4_" (x2*x2)
return (x3*x4)
flipFoldM :: (Foldable t, Monad m) => b -> t a -> (b -> a -> m b) -> m b
flipFoldM s0 list action = foldM action s0 list
flipFoldM_ :: (Foldable t, Monad m) => b -> t a -> (b -> a -> m b) -> m ()
flipFoldM_ s0 list action = void (foldM action s0 list)
-- | Poseidon state
type PS = [Expr Var_]
--------------------------------------------------------------------------------
poseidonGateConstraints :: Compute ()
poseidonGateConstraints = do
-- merkle swap
let input_lhs i = input i
let input_rhs i = input (i+4)
commit $ (1 - swap_flag) * swap_flag
commitList [ swap_flag * (input_rhs i - input_lhs i) - delta i | i <- range 4 ]
-- swapped inputs
let state0 :: PS
state0
= [ input_lhs i + delta i | i <- [0.. 3] ]
++ [ input_rhs (i-4) - delta (i-4) | i <- [4.. 7] ]
++ [ input i | i <- [8..11] ]
-- initial full rounds
phase1 <- flipFoldM state0 [0..3] $ \state1 r -> do
let state2 = plus_rc r state1
state3 <- if r == 0
then return state2
else do
let sbox_in = initial_sbox_in r
commitList [ (state2!!i) - sbox_in i | i <- range 12 ]
return [ sbox_in i | i <- range 12 ]
state4 <- mapM sbox state3
return $ mds state4
-- partial rounds
let state' = zipWith (+) phase1 (map LitE $ elems fast_PARTIAL_FIRST_ROUND_CONSTANT)
let state'' = mdsInitPartial state'
phase2 <- flipFoldM state'' [0..21] $ \state1 r -> do
let sbox_in = partial_sbox_in r
commit $ (state1!!0) - sbox_in
y <- sbox sbox_in
let z = if r < 21 then y + LitE (fast_PARTIAL_ROUND_CONSTANTS!r) else y
let state2 = z : tail state1
return $ mdsFastPartial r state2
-- final full rounds
phase3 <- flipFoldM phase2 [0..3] $ \state1 r -> do
let state2 = plus_rc (r+26) state1
let sbox_in = final_sbox_in r
commitList [ (state2!!i) - sbox_in i | i <- range 12 ]
state3 <- mapM sbox [ sbox_in i | i <- range 12 ]
return $ mds state3
-- constraint the output to be the result
commitList [ phase3!!i - output i | i <- range 12 ]
where
-- multiply by the MDS matrix
mds :: PS -> PS
mds state =
[ sum [ LitE (mdsMatrixCoeff i j) * x | (j,x) <- zip [0..] state ]
| i <- range 12
]
dotProd :: PS -> [F] -> Expr Var_
dotProd es cs = sum $ zipWith (\e c -> e * LitE c) es cs
mdsInitPartial :: PS -> PS
mdsInitPartial (first:rest)
= first
: [ sum [ LitE (partialMdsMatrixCoeff i j) * x | (j,x) <- zip [0..] rest ]
| i <- range 11
]
mdsFastPartial :: Int -> PS -> PS
mdsFastPartial r state@(s0:rest) = res where
m0 = mdsMatrixCoeff 0 0
cs = m0 : elems (fast_PARTIAL_ROUND_W_HATS!r)
d = dotProd state cs
res = d : [ x + s0 * LitE t | (x,t) <- zip rest (elems $ fast_PARTIAL_ROUND_VS!r) ]
-- add round constants
plus_rc :: Int -> PS -> PS
plus_rc r state = zipWith (+) state (map LitE $ elems (all_ROUND_CONSTANTS!r))
-- witness variables
input i = wire i
output i = wire (i+12)
swap_flag = wire 24
delta i = wire (25+i)
initial_sbox_in r i = wire (29 + 12*(r-1) + i) -- 0 < r < 4 , 0 <= i < 12
partial_sbox_in r = wire (29 + 36 + r) -- 0 <= r < 22
final_sbox_in r i = wire (29 + 36 + 22 + 12*r + i) -- 0 <= r < 4 , 0 <= i < 12
--------------------------------------------------------------------------------

58
src/Gate/Vars.hs Normal file
View File

@ -0,0 +1,58 @@
-- | Variables appearing in gate constraints
module Gate.Vars where
--------------------------------------------------------------------------------
import Text.Show
import Algebra.Expr
import Misc.Pretty
--------------------------------------------------------------------------------
-- * Constraint variables
-- | These index into a row + public input
data PlonkyVar
= SelV Int -- ^ selector variable
| ConstV Int -- ^ constant variable
| WireV Int -- ^ wire variable
| PIV Int -- ^ public input hash variable
deriving (Eq,Ord,Show)
instance Pretty PlonkyVar where
prettyPrec _ v = case v of
SelV k -> showString ("s" ++ show k)
ConstV k -> showString ("c" ++ show k)
WireV k -> showString ("w" ++ show k)
PIV k -> showString ("h" ++ show k)
--------------------------------------------------------------------------------
-- * Variables
data Var v
= LocalVar Int String -- ^ a temporary variable
| ProofVar v -- ^ a proof variable (eg. witness, constant, selector)
deriving (Eq,Show)
instance Pretty v => Pretty (Var v) where
prettyPrec d var = case var of
LocalVar k name -> showString ("_" ++ name ++ show k)
ProofVar v -> prettyPrec d v
--------------------------------------------------------------------------------
-- * Convenience
range :: Int -> [Int]
range k = [0..k-1]
wire, cnst, hash :: Int -> Expr (Var PlonkyVar)
wire i = VarE $ ProofVar $ WireV i -- witness variable
cnst i = VarE $ ProofVar $ ConstV i -- constant variable
hash i = VarE $ ProofVar $ PIV i -- public input hash component
wireExt :: Int -> Expr (Var PlonkyVar)
wireExt i = wire i + ImgE (wire (i+1)) -- use two consecutive variables as an extension field element
--------------------------------------------------------------------------------

View File

@ -40,10 +40,8 @@ fast_PARTIAL_ROUND_CONSTANTS = listArray (0,21)
, 0x1aca78f31c97c876 , 0x0 , 0x1aca78f31c97c876 , 0x0
] ]
{- fast_PARTIAL_ROUND_VS :: Array Int (Array Int F)
fast_PARTIAL_ROUND_VS = listArray (0,21) $ map (listArray (0,10))
fast_PARTIAL_ROUND_VS :: [Array Int F]
fast_PARTIAL_ROUND_VS = map (listArray (0,10))
[ [0x94877900674181c3, 0xc6c67cc37a2a2bbd, 0xd667c2055387940f, 0x0ba63a63e94b5ff0, 0x99460cc41b8f079f, 0x7ff02375ed524bb3, 0xea0870b47a8caf0e, 0xabcad82633b7bc9d, 0x3b8d135261052241, 0xfb4515f5e5b0d539, 0x3ee8011c2b37f77c ] [ [0x94877900674181c3, 0xc6c67cc37a2a2bbd, 0xd667c2055387940f, 0x0ba63a63e94b5ff0, 0x99460cc41b8f079f, 0x7ff02375ed524bb3, 0xea0870b47a8caf0e, 0xabcad82633b7bc9d, 0x3b8d135261052241, 0xfb4515f5e5b0d539, 0x3ee8011c2b37f77c ]
, [0x0adef3740e71c726, 0xa37bf67c6f986559, 0xc6b16f7ed4fa1b00, 0x6a065da88d8bfc3c, 0x4cabc0916844b46f, 0x407faac0f02e78d1, 0x07a786d9cf0852cf, 0x42433fb6949a629a, 0x891682a147ce43b0, 0x26cfd58e7b003b55, 0x2bbf0ed7b657acb3 ] , [0x0adef3740e71c726, 0xa37bf67c6f986559, 0xc6b16f7ed4fa1b00, 0x6a065da88d8bfc3c, 0x4cabc0916844b46f, 0x407faac0f02e78d1, 0x07a786d9cf0852cf, 0x42433fb6949a629a, 0x891682a147ce43b0, 0x26cfd58e7b003b55, 0x2bbf0ed7b657acb3 ]
, [0x481ac7746b159c67, 0xe367de32f108e278, 0x73f260087ad28bec, 0x5cfc82216bc1bdca, 0xcaccc870a2663a0e, 0xdb69cd7b4298c45d, 0x7bc9e0c57243e62d, 0x3cc51c5d368693ae, 0x366b4e8cc068895b, 0x2bd18715cdabbca4, 0xa752061c4f33b8cf ] , [0x481ac7746b159c67, 0xe367de32f108e278, 0x73f260087ad28bec, 0x5cfc82216bc1bdca, 0xcaccc870a2663a0e, 0xdb69cd7b4298c45d, 0x7bc9e0c57243e62d, 0x3cc51c5d368693ae, 0x366b4e8cc068895b, 0x2bd18715cdabbca4, 0xa752061c4f33b8cf ]
@ -68,8 +66,8 @@ fast_PARTIAL_ROUND_VS = map (listArray (0,10))
, [0x0000000000000014, 0x0000000000000022, 0x0000000000000012, 0x0000000000000027, 0x000000000000000d, 0x000000000000000d, 0x000000000000001c, 0x0000000000000002, 0x0000000000000010, 0x0000000000000029, 0x000000000000000f ] , [0x0000000000000014, 0x0000000000000022, 0x0000000000000012, 0x0000000000000027, 0x000000000000000d, 0x000000000000000d, 0x000000000000001c, 0x0000000000000002, 0x0000000000000010, 0x0000000000000029, 0x000000000000000f ]
] ]
fast_PARTIAL_ROUND_W_HATS :: [Array Int F] fast_PARTIAL_ROUND_W_HATS :: Array Int (Array Int F)
fast_PARTIAL_ROUND_W_HATS = map (listArray (0,10)) fast_PARTIAL_ROUND_W_HATS = listArray (0,21) $ map (listArray (0,10))
[ [0x3d999c961b7c63b0, 0x814e82efcd172529, 0x2421e5d236704588, 0x887af7d4dd482328, 0xa5e9c291f6119b27, 0xbdc52b2676a4b4aa, 0x64832009d29bcf57, 0x09c4155174a552cc, 0x463f9ee03d290810, 0xc810936e64982542, 0x043b1c289f7bc3ac ] [ [0x3d999c961b7c63b0, 0x814e82efcd172529, 0x2421e5d236704588, 0x887af7d4dd482328, 0xa5e9c291f6119b27, 0xbdc52b2676a4b4aa, 0x64832009d29bcf57, 0x09c4155174a552cc, 0x463f9ee03d290810, 0xc810936e64982542, 0x043b1c289f7bc3ac ]
, [0x673655aae8be5a8b, 0xd510fe714f39fa10, 0x2c68a099b51c9e73, 0xa667bfa9aa96999d, 0x4d67e72f063e2108, 0xf84dde3e6acda179, 0x40f9cc8c08f80981, 0x5ead032050097142, 0x6591b02092d671bb, 0x00e18c71963dd1b7, 0x8a21bcd24a14218a ] , [0x673655aae8be5a8b, 0xd510fe714f39fa10, 0x2c68a099b51c9e73, 0xa667bfa9aa96999d, 0x4d67e72f063e2108, 0xf84dde3e6acda179, 0x40f9cc8c08f80981, 0x5ead032050097142, 0x6591b02092d671bb, 0x00e18c71963dd1b7, 0x8a21bcd24a14218a ]
, [0x202800f4addbdc87, 0xe4b5bdb1cc3504ff, 0xbe32b32a825596e7, 0x8e0f68c5dc223b9a, 0x58022d9e1c256ce3, 0x584d29227aa073ac, 0x8b9352ad04bef9e7, 0xaead42a3f445ecbf, 0x3c667a1d833a3cca, 0xda6f61838efa1ffe, 0xe8f749470bd7c446 ] , [0x202800f4addbdc87, 0xe4b5bdb1cc3504ff, 0xbe32b32a825596e7, 0x8e0f68c5dc223b9a, 0x58022d9e1c256ce3, 0x584d29227aa073ac, 0x8b9352ad04bef9e7, 0xaead42a3f445ecbf, 0x3c667a1d833a3cca, 0xda6f61838efa1ffe, 0xe8f749470bd7c446 ]
@ -94,8 +92,6 @@ fast_PARTIAL_ROUND_W_HATS = map (listArray (0,10))
, [0x3abeb80def61cc85, 0x9d19c9dd4eac4133, 0x075a652d9641a985, 0x9daf69ae1b67e667, 0x364f71da77920a18, 0x50bd769f745c95b1, 0xf223d1180dbbf3fc, 0x2f885e584e04aa99, 0xb69a0fa70aea684a, 0x09584acaa6e062a0, 0x0bc051640145b19b ] , [0x3abeb80def61cc85, 0x9d19c9dd4eac4133, 0x075a652d9641a985, 0x9daf69ae1b67e667, 0x364f71da77920a18, 0x50bd769f745c95b1, 0xf223d1180dbbf3fc, 0x2f885e584e04aa99, 0xb69a0fa70aea684a, 0x09584acaa6e062a0, 0x0bc051640145b19b ]
] ]
-}
-- ^ NB: This is in ROW-major order to support cache-friendly pre-multiplication. -- ^ NB: This is in ROW-major order to support cache-friendly pre-multiplication.
fast_PARTIAL_ROUND_INITIAL_MATRIX :: Array (Int,Int) F fast_PARTIAL_ROUND_INITIAL_MATRIX :: Array (Int,Int) F
fast_PARTIAL_ROUND_INITIAL_MATRIX = listArray ((0,0),(10,10)) $ concat fast_PARTIAL_ROUND_INITIAL_MATRIX = listArray ((0,0),(10,10)) $ concat
@ -112,6 +108,9 @@ fast_PARTIAL_ROUND_INITIAL_MATRIX = listArray ((0,0),(10,10)) $ concat
, [0xd841e8ef9dde8ba0, 0x156048ee7a738154, 0x85418a9fef8a9890, 0x64dd936da878404d, 0x726af914971c1374, 0x7f8e41e0b0a6cdff, 0xf97abba0dffb6c50, 0xf4a437f2888ae909, 0xdcedab70f40718ba, 0xe796d293a47a64cb, 0x80772dc2645b280b ] , [0xd841e8ef9dde8ba0, 0x156048ee7a738154, 0x85418a9fef8a9890, 0x64dd936da878404d, 0x726af914971c1374, 0x7f8e41e0b0a6cdff, 0xf97abba0dffb6c50, 0xf4a437f2888ae909, 0xdcedab70f40718ba, 0xe796d293a47a64cb, 0x80772dc2645b280b ]
] ]
partialMdsMatrixCoeff :: Int -> Int -> F
partialMdsMatrixCoeff i j = fast_PARTIAL_ROUND_INITIAL_MATRIX ! (j,i)
partition12 :: [a] -> [[a]] partition12 :: [a] -> [[a]]
partition12 = go where partition12 = go where
go [] = [] go [] = []

15
src/Misc/Pretty.hs Normal file
View File

@ -0,0 +1,15 @@
-- | Precedence-aware pretty-printing
module Misc.Pretty where
--------------------------------------------------------------------------------
-- | See "Text.Show"
class Pretty a where
prettyPrec :: Int -> a -> (String -> String)
pretty :: Pretty a => a -> String
pretty x = prettyPrec 0 x ""
--------------------------------------------------------------------------------

82
src/Test/Witness.hs Normal file
View File

@ -0,0 +1,82 @@
-- | We can test the gate constraints on actual witness rows during development
{-# LANGUAGE StrictData, DeriveGeneric, DeriveAnyClass, RecordWildCards #-}
module Test.Witness where
--------------------------------------------------------------------------------
import Data.Array
import Data.Aeson
import GHC.Generics
import qualified Data.ByteString.Lazy as L
import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Gate.Base
import Gate.Computation
import Gate.Constraints
--------------------------------------------------------------------------------
-- | As exported into JSON by our Plonky2 fork
data Witness = MkWitness
{ gates :: [String]
, selector_vector :: [Int]
, selector_columns :: [[F]]
, constants_columns :: [[F]]
, matrix :: [[F]]
}
deriving (Show,Generic,ToJSON,FromJSON)
matrixRow :: [[F]] -> Int -> [F]
matrixRow matrix i = map (!!i) matrix
witnessRow :: Witness -> Int -> EvaluationVars F
witnessRow (MkWitness{..}) i = MkEvaluationVars
{ local_selectors = toArray (matrixRow selector_columns i)
, local_constants = toArray (matrixRow constants_columns i)
, local_wires = toArray (matrixRow matrix i)
, public_inputs_hash = []
}
where
toArray xs = listArray (0, length xs - 1) xs
loadWitness :: FilePath -> IO Witness
loadWitness fpath = do
txt <- L.readFile fpath
case decode txt of
Just w -> return w
Nothing -> fail "loadWitness: cannot parse JSON witness"
--------------------------------------------------------------------------------
test_fibonacci = do
witness <- loadWitness "../json/fibonacci_witness.json"
let arith_row = witnessRow witness 0
let posei_row = witnessRow witness 5
let pubio_row = witnessRow witness 6
let const_row = witnessRow witness 7
let arith_prg = gateProgram (ArithmeticGate 20)
let posei_prg = gateProgram (PoseidonGate 12)
let pubio_prg = gateProgram (PublicInputGate )
let const_prg = gateProgram (ConstantGate 2)
let arith_evals = runStraightLine (fmap fromBase arith_row) arith_prg
let const_evals = runStraightLine (fmap fromBase const_row) const_prg
let posei_evals = runStraightLine (fmap fromBase posei_row) posei_prg
putStrLn $ "number of constants in ArithmeticGate = " ++ show (length arith_evals)
putStrLn $ "number of constants in ConstantGate = " ++ show (length const_evals)
putStrLn $ "number of constants in PosiedoGate = " ++ show (length posei_evals)
print arith_evals
print const_evals
print posei_evals
--------------------------------------------------------------------------------