completely rewrite the constraints; PoseidonGate seems to work now

This commit is contained in:
Balazs Komuves 2024-12-15 14:03:44 +01:00
parent 338163f56d
commit c949f3d3f2
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
12 changed files with 563 additions and 157 deletions

View File

@ -2,7 +2,7 @@
A standalone Plonky2 verifier
-----------------------------
This is a (WIP) implementation of a Plonky2 verifier in Haskell.
This is a (WIP) implementation of a Plonky2 verifier written in Haskell.
[Plonky2](https://github.com/0xPolygonZero/plonky2/) is a zero-knowledge proof
system developed by Polygon Zero, optimized for recursive proofs.
@ -15,7 +15,7 @@ Another goal is to be a basis for further tooling (for example:
estimating verifier costs, helping the design of recursive circuits, generating
Plonky2 verifier circuits for other proof systems, etc)
Note: It's deliberately no a goal for this verifier to be efficient; instead we
Note: It's deliberately not a goal for this verifier to be efficient; instead we
try to focus on simplicity.
@ -42,7 +42,7 @@ Supported gates:
- [x] MulExtensionGate
- [x] NoopGate
- [x] PublicInputGate
- [ ] PoseidonGate
- [x] PoseidonGate
- [ ] PoseidonMdsGate
- [ ] RandomAccessGate
- [ ] ReducingGate
@ -52,6 +52,6 @@ Optional features:
- [ ] Field extensions with degree higher than 2
- [ ] Being parametric over the field choice
- [ ] Supporting different hash function choices
- [ ] Supporting different hash functions

View File

@ -7,6 +7,8 @@ module Algebra.Expr where
--------------------------------------------------------------------------------
import Prelude hiding ( (^) )
import Data.Array
import Data.Char
@ -16,102 +18,76 @@ import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Gate.Base
import Misc.Pretty
--------------------------------------------------------------------------------
-- * Polynomial expressions
data Expr v
= VarE v -- ^ a variable
| LitE F -- ^ constant literal
| ScaleE F (Expr v) -- ^ linear scaling by a constant
| ImagE (Expr v) -- ^ multiplies by the field extension generator X
| SumE [Expr v] -- ^ sum of expressions
| ProdE [Expr v] -- ^ product of expressions
| PowE (Expr v) Int -- ^ exponentiation
deriving (Eq) -- ,Show)
= VarE v -- ^ a variable
| LitE F -- ^ constant literal
| AddE (Expr v) (Expr v) -- ^ addition
| SubE (Expr v) (Expr v) -- ^ subtraction
| MulE (Expr v) (Expr v) -- ^ multiplication
| ImgE (Expr v) -- ^ multiplies by the field extension generator X
deriving (Eq,Show)
instance Pretty var => Show (Expr var) where show = pretty
-- instance Pretty var => Show (Expr var) where show = pretty
-- | Degree of the expression
exprDegree :: Expr var -> Int
exprDegree = go where
go expr = case expr of
VarE _ -> 1
LitE _ -> 0
ScaleE _ e -> go e
ImagE e -> go e
SumE es -> if null es then 0 else maximum (map go es)
ProdE es -> sum (map go es)
PowE e n -> n * go e
VarE _ -> 1
LitE _ -> 0
AddE e1 e2 -> max (go e1) (go e2)
SubE e1 e2 -> max (go e1) (go e2)
MulE e1 e2 -> go e1 + go e2
ImgE e -> go e
instance Num (Expr var) where
fromInteger = LitE . fromInteger
negate = negE
(+) = addE
(-) = subE
(*) = mulE
(+) = AddE
(-) = SubE
(*) = MulE
abs = error "Expr/abs"
signum = error "Expr/signum"
negE :: Expr var -> Expr var
negE (ScaleE s e) = ScaleE (negate s) e
negE e = ScaleE (-1) e
negE e = SubE (LitE 0) e
{-
(^) :: Expr var -> Int -> Expr var
(^) = PowE
addE :: Expr var -> Expr var -> Expr var
addE (SumE es) (SumE fs) = SumE (es++fs )
addE e (SumE fs) = SumE (e : fs )
addE (SumE es) f = SumE (es++[f])
addE e f = SumE [e,f]
subE :: Expr var -> Expr var -> Expr var
subE e f = addE e (negate f)
sclE :: F -> Expr var -> Expr var
sclE s (ScaleE t e) = sclE (s*t) e
sclE s e = ScaleE s e
addE = AddE
mulE :: Expr var -> Expr var -> Expr var
mulE (ScaleE s e) (ScaleE t f) = sclE (s*t) (mulE e f)
mulE (ScaleE s e) f = sclE s (mulE e f)
mulE (LitE s) f = sclE s f
mulE e (LitE t) = sclE t e
mulE e (ScaleE t f) = sclE t (mulE e f)
mulE (ProdE es) (ProdE fs) = ProdE (es++fs )
mulE e (ProdE fs) = ProdE (e : fs )
mulE (ProdE es) f = ProdE (es++[f])
mulE e f = ProdE [e,f]
mulE = MulE
-}
--------------------------------------------------------------------------------
-- * pretty printing
-- | TODO: maybe move this somewhere else
class Pretty a where
prettyPrec :: Int -> a -> (String -> String)
pretty :: Pretty a => a -> String
pretty x = prettyPrec 0 x ""
instance Pretty F where prettyPrec _ x = shows x
instance Pretty FExt where prettyPrec _ x = shows x
instance Pretty var => Pretty (Expr var) where
prettyPrec d expr =
case expr of
VarE v -> prettyPrec 0 v
LitE x -> prettyPrec 0 x
ScaleE s e -> prettyPrec 0 s . showString " * " . showParen (d > mul_prec) (prettyPrec mul_prec e)
ImagE e -> showString "X*" . showParen (d > mul_prec) (prettyPrec mul_prec e)
SumE es -> showParen (d > add_prec) $ intercalates " + " $ map (prettyPrec add_prec) es
ProdE es -> showParen (d > mul_prec) $ intercalates " * " $ map (prettyPrec mul_prec) es
PowE e k -> showParen (d > pow_prec) $ (prettyPrec pow_prec e) . showString ("^" ++ show k)
VarE v -> prettyPrec 0 v
LitE x -> prettyPrec 0 x
AddE e1 e2 -> showParen (d > add_prec) $ prettyPrec add_prec e1 . showString " + " . prettyPrec (add_prec+1) e2
SubE e1 e2 -> showParen (d > add_prec) $ prettyPrec add_prec e1 . showString " - " . prettyPrec (add_prec+1) e2
MulE e1 e2 -> showParen (d > mul_prec) $ prettyPrec add_prec e1 . showString " * " . prettyPrec (mul_prec+1) e2
ImgE e -> showParen (d > mul_prec) $ showString "X*" . (prettyPrec mul_prec e)
where
add_prec = 5
mul_prec = 6
pow_prec = 7
intercalates sep = go where
go [] = id
go [x] = x
go (x:xs) = x . showString sep . go xs
-- pow_prec = 7
-- intercalates sep = go where
-- go [] = id
-- go [x] = x
-- go (x:xs) = x . showString sep . go xs
--------------------------------------------------------------------------------
-- * Evaluation
@ -127,12 +103,11 @@ instance EvalField FExt where fromGoldilocks = fromBase
evalExprWith :: (var -> FExt) -> Expr var -> FExt
evalExprWith evalVar expr = go expr where
go e = case e of
VarE v -> evalVar v
LitE x -> fromBase x
ScaleE s e -> fromBase s * go e
ImagE e -> (MkExt 0 1) * go e
SumE es -> sum (map go es)
ProdE es -> product (map go es)
PowE e n -> powExt (go e) (fromIntegral n)
VarE v -> evalVar v
LitE x -> fromBase x
AddE e1 e2 -> go e1 + go e2
SubE e1 e2 -> go e1 - go e2
MulE e1 e2 -> go e1 * go e2
ImgE e -> (MkExt 0 1) * go e
--------------------------------------------------------------------------------

View File

@ -14,6 +14,7 @@ import Data.Word
import Data.Ratio
import Data.Array
import Text.Show
import Text.Printf
import System.Random
@ -21,6 +22,8 @@ import System.Random
import GHC.Generics
import Data.Aeson ( ToJSON(..), FromJSON(..) )
import Misc.Pretty
--------------------------------------------------------------------------------
type F = Goldilocks
@ -78,6 +81,8 @@ instance Show Goldilocks where
show (Goldilocks x) = show x -- decimal
-- show (Goldilocks x) = printf "0x%016x" x -- hex
instance Pretty Goldilocks where prettyPrec _ x = shows x
--------------------------------------------------------------------------------
instance ToJSON Goldilocks where

View File

@ -12,10 +12,12 @@ module Algebra.GoldilocksExt where
import Data.Bits
import Data.Ratio
import Text.Show
import Data.Aeson ( ToJSON(..), FromJSON(..) )
import Algebra.Goldilocks
import Misc.Pretty
--------------------------------------------------------------------------------
@ -31,6 +33,12 @@ fromBase x = MkExt x 0
instance Show GoldilocksExt where
show (MkExt real imag) = "(" ++ show real ++ " + X*" ++ show imag ++ ")"
instance Pretty GoldilocksExt where
prettyPrec d (MkExt real imag)
| imag == 0 = prettyPrec 0 real
| otherwise = showParen (d > 5)
$ prettyPrec 0 real . showString " + X*" . prettyPrec 0 imag
instance ToJSON GoldilocksExt where
toJSON (MkExt a b) = toJSON (a,b)
@ -81,3 +89,11 @@ powExt x e
_ -> go (acc*s) (sqrExt s) (shiftR expo 1)
--------------------------------------------------------------------------------
rndExt :: IO FExt
rndExt = do
x <- rndF
y <- rndF
return (MkExt x y)
--------------------------------------------------------------------------------

160
src/Gate/Computation.hs Normal file
View File

@ -0,0 +1,160 @@
-- | We have to describe the contraints as computations with local definitions
-- Without local definitions the equations would just blow up
-- in the case of the Poseidon gate for example
--
{-# LANGUAGE StrictData, DeriveFunctor, GADTs, RecordWildCards #-}
module Gate.Computation where
--------------------------------------------------------------------------------
import Prelude hiding ( (^) )
import Control.Applicative
import Control.Monad
import Data.Array
import Data.List
import Text.Show
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Algebra.Expr
import Gate.Vars
import Misc.Pretty
--------------------------------------------------------------------------------
-- * Operational monad
data Instr var a where
Let :: String -> Expr var -> Instr var var
Commit :: Expr var -> Instr var ()
data Program instr a where
Bind :: Program instr a -> (a -> Program instr b) -> Program instr b
Return :: a -> Program instr a
Instr :: instr a -> Program instr a
instance Functor (Program instr) where
fmap = liftM
instance Applicative (Program instr) where
(<*>) = ap
pure = Return
instance Monad (Program instr) where
(>>=) = Bind
type Var_ = Var PlonkyVar
type Expr_ = Expr Var_
type Def_ = LocalDef Var_
--type Instr_ a = Instr Var_ a
--------------------------------------------------------------------------------
-- | Our computation monad
type Compute a = Program (Instr Var_) a
let_ :: String -> Expr Var_ -> Compute (Expr Var_)
let_ name rhs = VarE <$> Instr (Let name rhs)
commit :: Expr Var_ -> Compute ()
commit what = Instr (Commit what)
commitList :: [Expr Var_] -> Compute ()
commitList = mapM_ commit
--------------------------------------------------------------------------------
-- | Straightline programs
data LocalDef v
= MkLocalDef Int String (Expr v)
deriving (Eq,Show)
instance Pretty v => Pretty (LocalDef v) where
prettyPrec _ (MkLocalDef k name rhs) = showString ("_" ++ name ++ show k) . showString " := " . prettyPrec 0 rhs
-- | A straightline program encoding the computation of constraints
data StraightLine = MkStraightLine
{ localdefs :: [LocalDef Var_] -- ^ local definitions, in reverse order
, commits :: [Expr_] -- ^ committed constraints, in reverse order
, counter :: Int -- ^ fresh variable counter
}
deriving Show
emptyStraightLine :: StraightLine
emptyStraightLine = MkStraightLine [] [] 0
printStraightLine :: StraightLine -> IO ()
printStraightLine (MkStraightLine{..}) = do
forM_ (reverse localdefs) $ \def -> putStrLn (pretty def)
forM_ (reverse commits ) $ \expr -> putStrLn $ "constraint 0 == " ++ (pretty expr)
compileToStraightLine :: Compute () -> StraightLine
compileToStraightLine = fst . go emptyStraightLine where
go :: StraightLine -> Compute a -> (StraightLine,a)
go state instr = case instr of
Return x -> (state,x)
Bind this rest -> let (state',x) = go state this in go state' (rest x)
Instr this -> case state of
MkStraightLine{..} -> case this of
Commit what -> let state' = MkStraightLine localdefs (what:commits) counter
in (state', ())
Let name rhs -> let def = MkLocalDef counter name rhs
state' = MkStraightLine (def:localdefs) commits (counter+1)
in (state', LocalVar counter name)
--------------------------------------------------------------------------------
type Scope a = IntMap a
emptyScope :: Scope a
emptyScope = IntMap.empty
-- | Run a \"straightline program\", resulting in list of contraints evaluations
runStraightLine :: EvaluationVars FExt -> StraightLine -> [FExt]
runStraightLine = runStraightLine' emptyScope
runStraightLine' :: Scope FExt -> EvaluationVars FExt -> StraightLine -> [FExt]
runStraightLine' iniScope vars (MkStraightLine{..}) = result where
finalScope = foldl' worker iniScope (reverse localdefs)
result = evalConstraints finalScope vars (reverse commits)
worker !scope (MkLocalDef i _ rhs) = IntMap.insert i (evalConstraint scope vars rhs) scope
--------------------------------------------------------------------------------
-- * Evaluation
type Constraint = Expr_
-- | List of all data (one "row") we need to evaluate a gate constraint
--
-- Typically this will be the evaluations of the column polynomials at @zeta@
data EvaluationVars a = MkEvaluationVars
{ local_selectors :: Array Int a -- ^ the selectors
, local_constants :: Array Int a -- ^ the circuit constants
, local_wires :: Array Int a -- ^ the advice wires (witness)
, public_inputs_hash :: [F] -- ^ only used in @PublicInputGate@
}
deriving (Show,Functor)
evalConstraint :: Scope FExt -> EvaluationVars FExt -> Constraint -> FExt
evalConstraint scope (MkEvaluationVars{..}) expr = evalExprWith f expr where
f var = case var of
LocalVar i n -> case IntMap.lookup i scope of
Just y -> y
Nothing -> error $ "variable _" ++ n ++ show i ++ " not in scope"
ProofVar v -> case v of
SelV k -> local_selectors ! k
ConstV k -> local_constants ! k
WireV k -> local_wires ! k
PIV k -> fromBase (public_inputs_hash !! k)
evalConstraints :: Scope FExt -> EvaluationVars FExt -> [Constraint] -> [FExt]
evalConstraints scope vars = map (evalConstraint scope vars)
--------------------------------------------------------------------------------

View File

@ -5,123 +5,92 @@
-- of constraints.
--
{-# LANGUAGE StrictData, RecordWildCards #-}
{-# LANGUAGE StrictData, DeriveFunctor, GADTs, RecordWildCards #-}
module Gate.Constraints where
--------------------------------------------------------------------------------
import Data.Array
import Data.Char
import Prelude hiding ( (^) )
import Data.Array hiding (range)
import Data.Char
import Text.Show
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Algebra.Expr
import Gate.Base
import Gate.Vars
import Gate.Computation
import Gate.Poseidon
--------------------------------------------------------------------------------
-- * Constraint expressions
-- | These index into a row + public input
data Var
= SelV Int -- ^ selector variable
| ConstV Int -- ^ constant variable
| WireV Int -- ^ wire variable
| PIV Int -- ^ public input hash variable
deriving (Eq,Ord,Show)
instance Pretty Var where
prettyPrec _ v = case v of
SelV k -> showString ("s" ++ show k)
ConstV k -> showString ("c" ++ show k)
WireV k -> showString ("w" ++ show k)
PIV k -> showString ("h" ++ show k)
--------------------------------------------------------------------------------
-- | List of all data (one "row") we need to evaluate a gate constraint
--
-- Typically this will be the evaluations of the column polynomials at @zeta@
data EvaluationVars a = MkEvaluationVars
{ local_selectors :: Array Int a -- ^ the selectors
, local_constants :: Array Int a -- ^ the circuit constants
, local_wires :: Array Int a -- ^ the advice wires (witness)
, public_inputs_hash :: [F] -- ^ only used in @PublicInputGate@
}
deriving (Show)
--------------------------------------------------------------------------------
-- * Evaluation
evalExpr :: Expr Var -> EvaluationVars FExt -> FExt
evalExpr expr (MkEvaluationVars{..}) = evalExprWith f expr where
f v = case v of
SelV k -> local_selectors ! k
ConstV k -> local_constants ! k
WireV k -> local_wires ! k
PIV k -> fromBase (public_inputs_hash !! k)
--------------------------------------------------------------------------------
-- * Gate constraints
-- | Returns the (symbolic) constraints for the given gate
gateConstraints :: Gate -> [Expr Var]
gateConstraints gate =
--
gateProgram :: Gate -> StraightLine
gateProgram = compileToStraightLine . gateComputation
gateComputation :: Gate -> Compute ()
gateComputation gate =
case gate of
-- `w[i] - c0*x[i]*y[i] - c1*z[i] = 0`
ArithmeticGate num_ops
-> [ ww (j+3) - cc 0 * ww j * ww (j+1) - cc 1 * ww (j+2) | i<-range num_ops, let j = 4*i ]
-> commitList [ wire (j+3) - cnst 0 * wire j * wire (j+1) - cnst 1 * wire (j+2) | i<-range num_ops, let j = 4*i ]
-- same but consecutive witness variables make up an extension field element
ArithmeticExtensionGate num_ops
-> [ wwExt (j+6) - cc 0 * wwExt j * wwExt (j+2) - cc 1 * wwExt (j+4) | i<-range num_ops, let j = 8*i ]
-> commitList [ wireExt (j+6) - cnst 0 * wireExt j * wireExt (j+2) - cnst 1 * wireExt (j+4) | i<-range num_ops, let j = 8*i ]
-- `sum b^i * limbs[i] - out = 0`, and `0 <= limb[i] < B` is enforced
BaseSumGate num_limbs base
-> let limb i = ww (i+1)
-> let limb i = wire (i+1)
horner = go 0 where go k = if k < num_limbs-1 then limb k + fromIntegral base * go (k+1) else limb k
sum_eq = horner - ww 0
range_eq i = ProdE [ limb i - fromIntegral k | k<-[0..base-1] ]
in sum_eq : [ range_eq i | i<-range num_limbs ]
sum_eq = horner - wire 0
range_eq i = product [ limb i - fromIntegral k | k<-[0..base-1] ]
in commitList $ sum_eq : [ range_eq i | i<-range num_limbs ]
CosetInterpolationGate subgroup_bits coset_degree barycentric_weights
-> todo
-- `c[i] - x[i] = 0`
ConstantGate num_consts
-> [ cc i - ww i | i <- range num_consts ]
-> commitList [ cnst i - wire i | i <- range num_consts ]
-- computes `out = base ^ (sum 2^i e_i)`
-- order of witness variables: [ base, e[0],...,e[n-1], output, t[0]...t[n-1] ]
ExponentiationGate num_power_bits
-> let base = ww 0
exp_bit i = ww (i+1)
out = ww (num_power_bits+1)
-> let base = wire 0
exp_bit i = wire (i+1)
out = wire (num_power_bits+1)
tmp_val 0 = 1
tmp_val i = ww (num_power_bits+1+i)
tmp_val i = wire (num_power_bits+1+i)
cur_bit i = exp_bit (num_power_bits - 1 - i)
eq i = tmp_val (i-1) * (cur_bit i * base + 1 - cur_bit i) - tmp_val i
in [ eq i | i <- range num_power_bits ] ++ [ out - tmp_val (num_power_bits-1) ]
in commitList $ [ eq i | i <- range num_power_bits ] ++ [ out - tmp_val (num_power_bits-1) ]
-- lookups are handled specially, no constraints here
LookupGate num_slots lut_hash -> []
LookupTableGate num_slots lut_hash last_lut_row -> []
LookupGate num_slots lut_hash -> return ()
LookupTableGate num_slots lut_hash last_lut_row -> return ()
-- `z[i] - c0*x[i]*y[i] = 0`, and two witness cells make up an extension field element
MulExtensionGate num_ops
-> [ wwExt (j+4) - cc 0 * wwExt j * wwExt (j+2) | i<-range num_ops, let j = 6*i ]
-> commitList [ wireExt (j+4) - cnst 0 * wireExt j * wireExt (j+2) | i<-range num_ops, let j = 6*i ]
NoopGate -> []
NoopGate -> return ()
-- equality with "hardcoded" hash components
PublicInputGate
-> [ hh i - ww i | i <- range 4 ]
-> commitList [ hash i - wire i | i <- range 4 ]
PoseidonGate hash_width -> case hash_width of
12 -> todo -- poseidonGateConstraints
12 -> poseidonGateConstraints
k -> error ( "gateConstraints/PoseidonGate: unsupported width " ++ show k)
PoseidonMdsGate hash_width -> case hash_width of
@ -143,12 +112,4 @@ gateConstraints gate =
todo = error $ "gateConstraints: gate `" ++ takeWhile isAlpha (show gate) ++ "` not yet implemented"
range k = [0..k-1]
ww i = VarE (WireV i) -- witness variable
cc i = VarE (ConstV i) -- constant variable
hh i = VarE (PIV i) -- public input hash component
wwExt i = ww i + ImagE (ww (i+1)) -- use two consecutive variables as an extension field element
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

View File

@ -1,5 +1,6 @@
-- | Gates are encoded as strings produced by ad-hoc of modification of Rust textual serialization...
-- | Gates are encoded as strings produced by some ad-hoc modifications
-- of the default Rust textual serialization...
--
-- ... so we have to parse /that/
--

134
src/Gate/Poseidon.hs Normal file
View File

@ -0,0 +1,134 @@
-- | Plonky2's Poseidon gate
--
{-# LANGUAGE StrictData, RecordWildCards #-}
module Gate.Poseidon where
--------------------------------------------------------------------------------
import Data.Array hiding (range)
import Data.Char
import Data.Foldable
import Control.Monad
import Control.Monad.State.Strict
import Algebra.Goldilocks
import Algebra.Expr
import Gate.Vars
import Gate.Computation
import Hash.Constants
--------------------------------------------------------------------------------
sbox :: Expr Var_ -> Compute (Expr Var_)
sbox x = do
x2 <- let_ "x2_" (x *x )
x3 <- let_ "x3_" (x *x2)
x4 <- let_ "x4_" (x2*x2)
return (x3*x4)
flipFoldM :: (Foldable t, Monad m) => b -> t a -> (b -> a -> m b) -> m b
flipFoldM s0 list action = foldM action s0 list
flipFoldM_ :: (Foldable t, Monad m) => b -> t a -> (b -> a -> m b) -> m ()
flipFoldM_ s0 list action = void (foldM action s0 list)
-- | Poseidon state
type PS = [Expr Var_]
--------------------------------------------------------------------------------
poseidonGateConstraints :: Compute ()
poseidonGateConstraints = do
-- merkle swap
let input_lhs i = input i
let input_rhs i = input (i+4)
commit $ (1 - swap_flag) * swap_flag
commitList [ swap_flag * (input_rhs i - input_lhs i) - delta i | i <- range 4 ]
-- swapped inputs
let state0 :: PS
state0
= [ input_lhs i + delta i | i <- [0.. 3] ]
++ [ input_rhs (i-4) - delta (i-4) | i <- [4.. 7] ]
++ [ input i | i <- [8..11] ]
-- initial full rounds
phase1 <- flipFoldM state0 [0..3] $ \state1 r -> do
let state2 = plus_rc r state1
state3 <- if r == 0
then return state2
else do
let sbox_in = initial_sbox_in r
commitList [ (state2!!i) - sbox_in i | i <- range 12 ]
return [ sbox_in i | i <- range 12 ]
state4 <- mapM sbox state3
return $ mds state4
-- partial rounds
let state' = zipWith (+) phase1 (map LitE $ elems fast_PARTIAL_FIRST_ROUND_CONSTANT)
let state'' = mdsInitPartial state'
phase2 <- flipFoldM state'' [0..21] $ \state1 r -> do
let sbox_in = partial_sbox_in r
commit $ (state1!!0) - sbox_in
y <- sbox sbox_in
let z = if r < 21 then y + LitE (fast_PARTIAL_ROUND_CONSTANTS!r) else y
let state2 = z : tail state1
return $ mdsFastPartial r state2
-- final full rounds
phase3 <- flipFoldM phase2 [0..3] $ \state1 r -> do
let state2 = plus_rc (r+26) state1
let sbox_in = final_sbox_in r
commitList [ (state2!!i) - sbox_in i | i <- range 12 ]
state3 <- mapM sbox [ sbox_in i | i <- range 12 ]
return $ mds state3
-- constraint the output to be the result
commitList [ phase3!!i - output i | i <- range 12 ]
where
-- multiply by the MDS matrix
mds :: PS -> PS
mds state =
[ sum [ LitE (mdsMatrixCoeff i j) * x | (j,x) <- zip [0..] state ]
| i <- range 12
]
dotProd :: PS -> [F] -> Expr Var_
dotProd es cs = sum $ zipWith (\e c -> e * LitE c) es cs
mdsInitPartial :: PS -> PS
mdsInitPartial (first:rest)
= first
: [ sum [ LitE (partialMdsMatrixCoeff i j) * x | (j,x) <- zip [0..] rest ]
| i <- range 11
]
mdsFastPartial :: Int -> PS -> PS
mdsFastPartial r state@(s0:rest) = res where
m0 = mdsMatrixCoeff 0 0
cs = m0 : elems (fast_PARTIAL_ROUND_W_HATS!r)
d = dotProd state cs
res = d : [ x + s0 * LitE t | (x,t) <- zip rest (elems $ fast_PARTIAL_ROUND_VS!r) ]
-- add round constants
plus_rc :: Int -> PS -> PS
plus_rc r state = zipWith (+) state (map LitE $ elems (all_ROUND_CONSTANTS!r))
-- witness variables
input i = wire i
output i = wire (i+12)
swap_flag = wire 24
delta i = wire (25+i)
initial_sbox_in r i = wire (29 + 12*(r-1) + i) -- 0 < r < 4 , 0 <= i < 12
partial_sbox_in r = wire (29 + 36 + r) -- 0 <= r < 22
final_sbox_in r i = wire (29 + 36 + 22 + 12*r + i) -- 0 <= r < 4 , 0 <= i < 12
--------------------------------------------------------------------------------

58
src/Gate/Vars.hs Normal file
View File

@ -0,0 +1,58 @@
-- | Variables appearing in gate constraints
module Gate.Vars where
--------------------------------------------------------------------------------
import Text.Show
import Algebra.Expr
import Misc.Pretty
--------------------------------------------------------------------------------
-- * Constraint variables
-- | These index into a row + public input
data PlonkyVar
= SelV Int -- ^ selector variable
| ConstV Int -- ^ constant variable
| WireV Int -- ^ wire variable
| PIV Int -- ^ public input hash variable
deriving (Eq,Ord,Show)
instance Pretty PlonkyVar where
prettyPrec _ v = case v of
SelV k -> showString ("s" ++ show k)
ConstV k -> showString ("c" ++ show k)
WireV k -> showString ("w" ++ show k)
PIV k -> showString ("h" ++ show k)
--------------------------------------------------------------------------------
-- * Variables
data Var v
= LocalVar Int String -- ^ a temporary variable
| ProofVar v -- ^ a proof variable (eg. witness, constant, selector)
deriving (Eq,Show)
instance Pretty v => Pretty (Var v) where
prettyPrec d var = case var of
LocalVar k name -> showString ("_" ++ name ++ show k)
ProofVar v -> prettyPrec d v
--------------------------------------------------------------------------------
-- * Convenience
range :: Int -> [Int]
range k = [0..k-1]
wire, cnst, hash :: Int -> Expr (Var PlonkyVar)
wire i = VarE $ ProofVar $ WireV i -- witness variable
cnst i = VarE $ ProofVar $ ConstV i -- constant variable
hash i = VarE $ ProofVar $ PIV i -- public input hash component
wireExt :: Int -> Expr (Var PlonkyVar)
wireExt i = wire i + ImgE (wire (i+1)) -- use two consecutive variables as an extension field element
--------------------------------------------------------------------------------

View File

@ -40,10 +40,8 @@ fast_PARTIAL_ROUND_CONSTANTS = listArray (0,21)
, 0x1aca78f31c97c876 , 0x0
]
{-
fast_PARTIAL_ROUND_VS :: [Array Int F]
fast_PARTIAL_ROUND_VS = map (listArray (0,10))
fast_PARTIAL_ROUND_VS :: Array Int (Array Int F)
fast_PARTIAL_ROUND_VS = listArray (0,21) $ map (listArray (0,10))
[ [0x94877900674181c3, 0xc6c67cc37a2a2bbd, 0xd667c2055387940f, 0x0ba63a63e94b5ff0, 0x99460cc41b8f079f, 0x7ff02375ed524bb3, 0xea0870b47a8caf0e, 0xabcad82633b7bc9d, 0x3b8d135261052241, 0xfb4515f5e5b0d539, 0x3ee8011c2b37f77c ]
, [0x0adef3740e71c726, 0xa37bf67c6f986559, 0xc6b16f7ed4fa1b00, 0x6a065da88d8bfc3c, 0x4cabc0916844b46f, 0x407faac0f02e78d1, 0x07a786d9cf0852cf, 0x42433fb6949a629a, 0x891682a147ce43b0, 0x26cfd58e7b003b55, 0x2bbf0ed7b657acb3 ]
, [0x481ac7746b159c67, 0xe367de32f108e278, 0x73f260087ad28bec, 0x5cfc82216bc1bdca, 0xcaccc870a2663a0e, 0xdb69cd7b4298c45d, 0x7bc9e0c57243e62d, 0x3cc51c5d368693ae, 0x366b4e8cc068895b, 0x2bd18715cdabbca4, 0xa752061c4f33b8cf ]
@ -68,8 +66,8 @@ fast_PARTIAL_ROUND_VS = map (listArray (0,10))
, [0x0000000000000014, 0x0000000000000022, 0x0000000000000012, 0x0000000000000027, 0x000000000000000d, 0x000000000000000d, 0x000000000000001c, 0x0000000000000002, 0x0000000000000010, 0x0000000000000029, 0x000000000000000f ]
]
fast_PARTIAL_ROUND_W_HATS :: [Array Int F]
fast_PARTIAL_ROUND_W_HATS = map (listArray (0,10))
fast_PARTIAL_ROUND_W_HATS :: Array Int (Array Int F)
fast_PARTIAL_ROUND_W_HATS = listArray (0,21) $ map (listArray (0,10))
[ [0x3d999c961b7c63b0, 0x814e82efcd172529, 0x2421e5d236704588, 0x887af7d4dd482328, 0xa5e9c291f6119b27, 0xbdc52b2676a4b4aa, 0x64832009d29bcf57, 0x09c4155174a552cc, 0x463f9ee03d290810, 0xc810936e64982542, 0x043b1c289f7bc3ac ]
, [0x673655aae8be5a8b, 0xd510fe714f39fa10, 0x2c68a099b51c9e73, 0xa667bfa9aa96999d, 0x4d67e72f063e2108, 0xf84dde3e6acda179, 0x40f9cc8c08f80981, 0x5ead032050097142, 0x6591b02092d671bb, 0x00e18c71963dd1b7, 0x8a21bcd24a14218a ]
, [0x202800f4addbdc87, 0xe4b5bdb1cc3504ff, 0xbe32b32a825596e7, 0x8e0f68c5dc223b9a, 0x58022d9e1c256ce3, 0x584d29227aa073ac, 0x8b9352ad04bef9e7, 0xaead42a3f445ecbf, 0x3c667a1d833a3cca, 0xda6f61838efa1ffe, 0xe8f749470bd7c446 ]
@ -94,8 +92,6 @@ fast_PARTIAL_ROUND_W_HATS = map (listArray (0,10))
, [0x3abeb80def61cc85, 0x9d19c9dd4eac4133, 0x075a652d9641a985, 0x9daf69ae1b67e667, 0x364f71da77920a18, 0x50bd769f745c95b1, 0xf223d1180dbbf3fc, 0x2f885e584e04aa99, 0xb69a0fa70aea684a, 0x09584acaa6e062a0, 0x0bc051640145b19b ]
]
-}
-- ^ NB: This is in ROW-major order to support cache-friendly pre-multiplication.
fast_PARTIAL_ROUND_INITIAL_MATRIX :: Array (Int,Int) F
fast_PARTIAL_ROUND_INITIAL_MATRIX = listArray ((0,0),(10,10)) $ concat
@ -112,6 +108,9 @@ fast_PARTIAL_ROUND_INITIAL_MATRIX = listArray ((0,0),(10,10)) $ concat
, [0xd841e8ef9dde8ba0, 0x156048ee7a738154, 0x85418a9fef8a9890, 0x64dd936da878404d, 0x726af914971c1374, 0x7f8e41e0b0a6cdff, 0xf97abba0dffb6c50, 0xf4a437f2888ae909, 0xdcedab70f40718ba, 0xe796d293a47a64cb, 0x80772dc2645b280b ]
]
partialMdsMatrixCoeff :: Int -> Int -> F
partialMdsMatrixCoeff i j = fast_PARTIAL_ROUND_INITIAL_MATRIX ! (j,i)
partition12 :: [a] -> [[a]]
partition12 = go where
go [] = []

15
src/Misc/Pretty.hs Normal file
View File

@ -0,0 +1,15 @@
-- | Precedence-aware pretty-printing
module Misc.Pretty where
--------------------------------------------------------------------------------
-- | See "Text.Show"
class Pretty a where
prettyPrec :: Int -> a -> (String -> String)
pretty :: Pretty a => a -> String
pretty x = prettyPrec 0 x ""
--------------------------------------------------------------------------------

82
src/Test/Witness.hs Normal file
View File

@ -0,0 +1,82 @@
-- | We can test the gate constraints on actual witness rows during development
{-# LANGUAGE StrictData, DeriveGeneric, DeriveAnyClass, RecordWildCards #-}
module Test.Witness where
--------------------------------------------------------------------------------
import Data.Array
import Data.Aeson
import GHC.Generics
import qualified Data.ByteString.Lazy as L
import Algebra.Goldilocks
import Algebra.GoldilocksExt
import Gate.Base
import Gate.Computation
import Gate.Constraints
--------------------------------------------------------------------------------
-- | As exported into JSON by our Plonky2 fork
data Witness = MkWitness
{ gates :: [String]
, selector_vector :: [Int]
, selector_columns :: [[F]]
, constants_columns :: [[F]]
, matrix :: [[F]]
}
deriving (Show,Generic,ToJSON,FromJSON)
matrixRow :: [[F]] -> Int -> [F]
matrixRow matrix i = map (!!i) matrix
witnessRow :: Witness -> Int -> EvaluationVars F
witnessRow (MkWitness{..}) i = MkEvaluationVars
{ local_selectors = toArray (matrixRow selector_columns i)
, local_constants = toArray (matrixRow constants_columns i)
, local_wires = toArray (matrixRow matrix i)
, public_inputs_hash = []
}
where
toArray xs = listArray (0, length xs - 1) xs
loadWitness :: FilePath -> IO Witness
loadWitness fpath = do
txt <- L.readFile fpath
case decode txt of
Just w -> return w
Nothing -> fail "loadWitness: cannot parse JSON witness"
--------------------------------------------------------------------------------
test_fibonacci = do
witness <- loadWitness "../json/fibonacci_witness.json"
let arith_row = witnessRow witness 0
let posei_row = witnessRow witness 5
let pubio_row = witnessRow witness 6
let const_row = witnessRow witness 7
let arith_prg = gateProgram (ArithmeticGate 20)
let posei_prg = gateProgram (PoseidonGate 12)
let pubio_prg = gateProgram (PublicInputGate )
let const_prg = gateProgram (ConstantGate 2)
let arith_evals = runStraightLine (fmap fromBase arith_row) arith_prg
let const_evals = runStraightLine (fmap fromBase const_row) const_prg
let posei_evals = runStraightLine (fmap fromBase posei_row) posei_prg
putStrLn $ "number of constants in ArithmeticGate = " ++ show (length arith_evals)
putStrLn $ "number of constants in ConstantGate = " ++ show (length const_evals)
putStrLn $ "number of constants in PosiedoGate = " ++ show (length posei_evals)
print arith_evals
print const_evals
print posei_evals
--------------------------------------------------------------------------------