From 644832ec4838168b875e7ebc99ce98fbc6f3ee44 Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Thu, 23 Jan 2025 18:59:51 +0100 Subject: [PATCH] implement CosetInterpolationGate --- README.md | 11 ++- src/Algebra/Goldilocks.hs | 10 ++ src/Gate/Base.hs | 2 +- src/Gate/Constraints.hs | 28 ++++-- src/Gate/Custom/CosetInterp.hs | 127 ++++++++++++++++++++++++++ src/Gate/{ => Custom}/Poseidon.hs | 2 +- src/Gate/{ => Custom}/RandomAccess.hs | 22 ++--- src/Misc/Aux.hs | 20 ++++ 8 files changed, 195 insertions(+), 27 deletions(-) create mode 100644 src/Gate/Custom/CosetInterp.hs rename src/Gate/{ => Custom}/Poseidon.hs (99%) rename src/Gate/{ => Custom}/RandomAccess.hs (82%) diff --git a/README.md b/README.md index 4263107..39c950a 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,8 @@ This is a (WIP) implementation of a Plonky2 verifier written in Haskell. system developed by Polygon Zero, optimized for recursive proofs. The goal here is to provide an executable specification (along a with less precise, -but [still detailed](commentary/Overview.md) human language description) of the Plonky2 verification -algorithm. +but [still detailed](commentary/Overview.md) human language description) of +the Plonky2 verification algorithm. Another goal is to be a basis for further tooling (for example: estimating verifier costs, helping the design of recursive circuits, generating @@ -22,20 +22,21 @@ try to focus on simplicity. ### Implementation status - [x] Parsing the proof and verification key from JSON -- [ ] Parsing from Plonky's custom binary serialization +- [ ] Parsing from Plonky2's custom binary serialization - [x] Generating verifier challenges - [ ] Recursive circuit subtle details (like [this](https://github.com/0xPolygonZero/plonky2/blob/356aefb6863ac881fb71f9bf851582c915428458/plonky2/src/fri/challenges.rs#L55-L64])) - [x] Constraints check - [ ] FRI check - [ ] Support lookup tables - [x] Documenting Plonky2 internals and the verifier algorithm (WIP) +- [ ] Cabalize Supported gates: - [x] ArithmeticGate - [x] ArithmeticExtensionGate - [x] BaseSumGate -- [ ] CosetInterpolationGate +- [x] CosetInterpolationGate - [x] ConstantGate - [x] ExponentiationGate - [ ] LookupGate @@ -45,7 +46,7 @@ Supported gates: - [x] PublicInputGate - [x] PoseidonGate - [ ] PoseidonMdsGate -- [X] RandomAccessGate +- [x] RandomAccessGate - [ ] ReducingGate - [ ] ReducingExtensionGate diff --git a/src/Algebra/Goldilocks.hs b/src/Algebra/Goldilocks.hs index 3f04fb1..701f243 100644 --- a/src/Algebra/Goldilocks.hs +++ b/src/Algebra/Goldilocks.hs @@ -11,6 +11,7 @@ import qualified Prelude import Data.Bits import Data.Word +import Data.List import Data.Ratio import Data.Array @@ -23,6 +24,7 @@ import GHC.Generics import Data.Aeson ( ToJSON(..), FromJSON(..) ) import Misc.Pretty +import Misc.Aux -------------------------------------------------------------------------------- @@ -68,6 +70,14 @@ rootsOfUnity = listArray (0,32) $ reverse $ go twoAdicGen where go 1 = [1] go x = x : go (x*x) +subgroupGenerator :: Log2 -> F +subgroupGenerator (Log2 k) = rootsOfUnity!k + +enumerateSubgroup :: Log2 -> [F] +enumerateSubgroup logSize = scanl' (\x _ -> x*g) 1 [1..n-1] where + g = subgroupGenerator logSize + n = exp2 logSize + -------------------------------------------------------------------------------- newtype Goldilocks diff --git a/src/Gate/Base.hs b/src/Gate/Base.hs index ee74802..641077a 100644 --- a/src/Gate/Base.hs +++ b/src/Gate/Base.hs @@ -28,7 +28,7 @@ data Gate = ArithmeticGate { num_ops :: Int } | ArithmeticExtensionGate { num_ops :: Int } | BaseSumGate { num_limbs :: Int , base :: Int } - | CosetInterpolationGate { subgroup_bits :: Int, coset_degree :: Int , barycentric_weights :: [F] } + | CosetInterpolationGate { subgroup_bits :: Int, constr_degree :: Int , barycentric_weights :: [F] } | ConstantGate { num_consts :: Int } | ExponentiationGate { num_power_bits :: Int } | LookupGate { num_slots :: Int, lut_hash :: KeccakHash } diff --git a/src/Gate/Constraints.hs b/src/Gate/Constraints.hs index a6dbfc0..7df69c6 100644 --- a/src/Gate/Constraints.hs +++ b/src/Gate/Constraints.hs @@ -22,9 +22,9 @@ import Algebra.Expr import Gate.Base import Gate.Vars import Gate.Computation -import Gate.Poseidon -import Gate.RandomAccess - +import Gate.Custom.Poseidon +import Gate.Custom.RandomAccess +import Gate.Custom.CosetInterp import Misc.Aux -------------------------------------------------------------------------------- @@ -60,8 +60,8 @@ gateComputation gate = range_eq i = product [ limb i - fromIntegral k | k<-[0..base-1] ] in commitList $ sum_eq : [ range_eq i | i<-range num_limbs ] - CosetInterpolationGate subgroup_bits coset_degree barycentric_weights - -> todo + CosetInterpolationGate subgroup_bits constr_degree barycentric_weights + -> cosetInterpolationGateConstraints $ MkCICfg (Log2 subgroup_bits) constr_degree barycentric_weights -- `c[i] - x[i] = 0` ConstantGate num_consts @@ -97,7 +97,7 @@ gateComputation gate = k -> error ( "gateConstraints/PoseidonMdsGate: unsupported width " ++ show k) RandomAccessGate num_bits num_copies num_extra_constants - -> randomAccessGateConstraints (MkRACfg num_bits num_copies num_extra_constants) + -> randomAccessGateConstraints (MkRACfg (Log2 num_bits) num_copies num_extra_constants) ReducingGate num_coeffs -> todo @@ -132,9 +132,19 @@ exponentiationGateConstraints num_power_bits = sqr x = x*x -------------------------------------------------------------------------------- +-- * Debugging -testArtihExtGate = runComputation testEvaluationVarsExt (gateComputation (ArithmeticExtensionGate 10)) -testMulExtGate = runComputation testEvaluationVarsExt (gateComputation (MulExtensionGate 13)) -testExpoGate = runComputation testEvaluationVarsExt (gateComputation (ExponentiationGate 13)) +testCompute :: Compute () -> [FExt] +testCompute = runComputation testEvaluationVarsExt + +testArtihExtGate = testCompute $ gateComputation (ArithmeticExtensionGate 10) +testBaseSum2 = testCompute $ gateComputation (BaseSumGate 13 2) +testBaseSum3 = testCompute $ gateComputation (BaseSumGate 13 3) +testExpoGate = testCompute $ gateComputation (ExponentiationGate 13) +testMulExtGate = testCompute $ gateComputation (MulExtensionGate 13) +testCosetGate3 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 3) +testCosetGate4 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 4) +testCosetGate5 = testCompute $ cosetInterpolationGateConstraints $ cosetInterpolationGateConfig (Log2 5) +testRandAccGate = testCompute $ randomAccessGateConstraints $ randomAccessGateConfig (Log2 4) -------------------------------------------------------------------------------- diff --git a/src/Gate/Custom/CosetInterp.hs b/src/Gate/Custom/CosetInterp.hs new file mode 100644 index 0000000..31a69f8 --- /dev/null +++ b/src/Gate/Custom/CosetInterp.hs @@ -0,0 +1,127 @@ + +-- | The @CosetInterpolation@ gate + +{-# LANGUAGE StrictData, RecordWildCards #-} +module Gate.Custom.CosetInterp where + +-------------------------------------------------------------------------------- + +import Data.Foldable +import Control.Monad + +import Algebra.Goldilocks +import Algebra.GoldilocksExt +import Algebra.Expr + +import Gate.Vars +import Gate.Computation + +import Misc.Aux + +-------------------------------------------------------------------------------- + +data CosetInterpolationGateConfig = MkCICfg + { ci_subgroup_bits :: Log2 -- ^ logarithm of the size of the subgroup (typically 4) + , ci_degree :: Int -- ^ equation degree (?) + , ci_barycentric_weights :: [F] -- ^ barycentric weights + } + deriving Show + +cosetInterpolationGateConfig :: Log2 -> CosetInterpolationGateConfig +cosetInterpolationGateConfig subgroup_bits = ci_cfg where + ci_cfg = MkCICfg + { ci_subgroup_bits = subgroup_bits + , ci_degree = degree + , ci_barycentric_weights = calcBarycentricWeights coset + } + max_degree = 8 + n_points = exp2 subgroup_bits + n_intermed_guess = (n_points - 2) `Prelude.div` (max_degree - 1) -- ??? + degree = (n_points - 2) `Prelude.div` (n_intermed_guess + 1) + 2 -- ??? + coset = enumerateSubgroup subgroup_bits + +-- | See +calcBarycentricWeights :: [F] -> [F] +calcBarycentricWeights locations = weights where + weights = map recip $ map f $ select1 locations + f (x,ys) = product [ x - y | y<-ys ] + +-------------------------------------------------------------------------------- + +cosetInterpolationGateConstraints :: CosetInterpolationGateConfig -> Compute () +cosetInterpolationGateConstraints (MkCICfg{..}) = do + + let MkExt u v = eval_loc - scaleExt coset_shift shifted_loc + commitList [ u , v ] + + let initials = initial : [ (tmp_eval i , tmp_prod i) | i <- [0..n_intermediates-1] ] + let chunks = zip3 chunked_domain chunked_values chunked_weights + + let worker ini (d,v,w) = partial_interpolate d v w ini + let stuff = zipWith worker initials chunks + + forM_ (zip [0..] (init stuff)) $ \(i,(eval,prod)) -> do + let MkExt u1 v1 = tmp_eval i - eval + let MkExt u2 v2 = tmp_prod i - prod + commitList [ u1 , v1 , u2 , v2 ] + + let (final_eval,_) = last stuff + let MkExt u v = eval_result - final_eval + commitList [ u , v ] + + where + + max_degree = 8 + degree = ci_degree + n_points = exp2 ci_subgroup_bits + n_intermediates = (n_points - 2) `Prelude.div` (ci_degree - 1) + + domain = enumerateSubgroup ci_subgroup_bits :: [F] + values = [ poly_value k | k <- range n_points ] :: [Ext Expr_] + + chunked_domain = chunk domain + chunked_values = chunk values + chunked_weights = chunk ci_barycentric_weights + + -- witness variables + coset_shift = wire 0 :: Expr_ + poly_value k = wireExt $ 1 + 2*k :: Ext Expr_ + eval_loc = wireExt $ 1 + 2*n_points :: Ext Expr_ + eval_result = wireExt $ 1 + 2*n_points + 2 :: Ext Expr_ + tmp_eval i = wireExt $ 1 + 2*(n_points+2) + 2*i :: Ext Expr_ + tmp_prod i = wireExt $ 1 + 2*(n_points+2) + 2*(n_intermediates + i) :: Ext Expr_ + shifted_loc = wireExt $ 1 + 2*(n_points+2) + 4* n_intermediates :: Ext Expr_ + + initial = (MkExt 0 0 , MkExt 1 0) :: (Ext Expr_, Ext Expr_) + + -- we use this formula + -- but in 1) chunks and 2) iteratively + -- + -- this is what happens: + -- + -- run0 = 0 + -- run1 = run0*(x - x1) + val1*1 + -- run2 = run1*(x - x2) + val2*(x - x1) + -- run3 = run2*(x - x3) + val3*(x - x1)*(x - x2) + -- run4 = run3*(x - x4) + val4*(x - x1)*(x - x2)*(x - x3) + -- run5 = run4*(x - x5) + val5*(x - x1)*(x - x2)*(x - x3)*(x - x4) + -- + -- this computes barycentric formula (val_i already conntains the barycentric weights) + -- + partial_interpolate :: [F] -> [Ext Expr_] -> [F] -> (Ext Expr_, Ext Expr_) -> (Ext Expr_, Ext Expr_) + partial_interpolate domain values weights ini = result where + weighted = zipWith scaleExt (map LitE weights) values + result = foldl f ini (zip weighted domain) + x0 = shifted_loc + + f :: (Ext Expr_, Ext Expr_) -> (Ext Expr_, F) -> (Ext Expr_, Ext Expr_) + f (eval,prod) (val, xi) = (next_eval,next_prod) where + term = x0 - fromBase (LitE xi) + next_eval = term * eval + val * prod + next_prod = term * prod + + -- the first chunk has degree one less so it can have length one more... + chunk xs = take degree xs : partition (degree-1) (drop degree xs) + +-------------------------------------------------------------------------------- + diff --git a/src/Gate/Poseidon.hs b/src/Gate/Custom/Poseidon.hs similarity index 99% rename from src/Gate/Poseidon.hs rename to src/Gate/Custom/Poseidon.hs index 0fc4351..8f4d3ff 100644 --- a/src/Gate/Poseidon.hs +++ b/src/Gate/Custom/Poseidon.hs @@ -3,7 +3,7 @@ -- {-# LANGUAGE StrictData, RecordWildCards #-} -module Gate.Poseidon where +module Gate.Custom.Poseidon where -------------------------------------------------------------------------------- diff --git a/src/Gate/RandomAccess.hs b/src/Gate/Custom/RandomAccess.hs similarity index 82% rename from src/Gate/RandomAccess.hs rename to src/Gate/Custom/RandomAccess.hs index 98a27a8..d8f1a33 100644 --- a/src/Gate/RandomAccess.hs +++ b/src/Gate/Custom/RandomAccess.hs @@ -2,7 +2,7 @@ -- | The @RandomAccess@ gate {-# LANGUAGE StrictData, RecordWildCards #-} -module Gate.RandomAccess where +module Gate.Custom.RandomAccess where -------------------------------------------------------------------------------- @@ -24,15 +24,15 @@ num_routed_columns = 80 :: Int -------------------------------------------------------------------------------- data RandomAccessGateConfig = MkRACfg - { ra_num_bits :: Int -- ^ number of bits in the index (so the vector has width @2^n@) + { ra_num_bits :: Log2 -- ^ number of bits in the index (so the vector has width @2^n@) , ra_num_copies :: Int -- ^ how many copies of this operation is included in a row , ra_num_extra_constants :: Int -- ^ number of extra cells used as in ConstantGate } deriving Show -randomAccessGateConfig :: Int -> RandomAccessGateConfig +randomAccessGateConfig :: Log2 -> RandomAccessGateConfig randomAccessGateConfig num_bits = ra_cfg where - veclen = 2 ^ num_bits + veclen = exp2 num_bits width = 2 + veclen copies = Prelude.div num_routed_columns width extra = min 2 (num_routed_columns - copies*width) @@ -50,22 +50,24 @@ randomAccessGateConstraints (MkRACfg{..}) = do forM_ [0..ra_num_copies-1] $ \k -> do -- index bits are actual bits - commitList [ bits k j * (bits k j - 1) | j<-[0..ra_num_bits-1] ] + commitList [ bits k j * (bits k j - 1) | j<-[0..num_bits-1] ] -- bit decomposition is correct - let reconstr = foldr (\b acc -> 2*acc + b) 0 [ bits k j | j<-[0..ra_num_bits-1] ] + let reconstr = foldr (\b acc -> 2*acc + b) 0 [ bits k j | j<-[0..num_bits-1] ] commit $ reconstr - index k let lkp_val = lookup_eq - [ bits k j | j<-[0..ra_num_bits-1] ] + [ bits k j | j<-[0..num_bits-1] ] [ inputs k i | i<-[0..veclen-1] ] commit $ lkp_val - output k forM_ [0..ra_num_extra_constants-1] $ \j -> commit (cnst j - extra j) where + + Log2 num_bits = ra_num_bits - veclen = 2 ^ ra_num_bits + veclen = exp2 ra_num_bits width = 2 + veclen bits_start_at = width * ra_num_copies + ra_num_extra_constants @@ -75,7 +77,7 @@ randomAccessGateConstraints (MkRACfg{..}) = do output k = wire $ k*width + 1 inputs k j = wire $ k*width + 2 + j extra j = wire $ ra_num_copies * width + j - bits k j = wire $ bits_start_at + k*ra_num_bits + j + bits k j = wire $ bits_start_at + k*num_bits + j into_pairs [] = [] into_pairs (x:y:rest) = (x,y) : into_pairs rest @@ -87,5 +89,3 @@ randomAccessGateConstraints (MkRACfg{..}) = do -------------------------------------------------------------------------------- -testRAGate = runComputation testEvaluationVarsExt (randomAccessGateConstraints $ randomAccessGateConfig 4) - diff --git a/src/Misc/Aux.hs b/src/Misc/Aux.hs index 2dbaef8..446f072 100644 --- a/src/Misc/Aux.hs +++ b/src/Misc/Aux.hs @@ -7,6 +7,7 @@ module Misc.Aux where -------------------------------------------------------------------------------- import Data.Array +import Data.Bits import Data.List import Data.Aeson hiding ( Array , pairs ) @@ -14,6 +15,18 @@ import GHC.Generics -------------------------------------------------------------------------------- +newtype Log2 + = Log2 Int + deriving (Eq,Ord,Show,Num) + +fromLog2 :: Log2 -> Int +fromLog2 (Log2 k) = k + +exp2 :: Log2 -> Int +exp2 (Log2 k) = shiftL 1 k + +-------------------------------------------------------------------------------- + range :: Int -> [Int] range k = [0..k-1] @@ -50,6 +63,13 @@ partition k = go where go [] = [] go xs = take k xs : go (drop k xs) +-- | all possible ways to select 1 element out of a (nonempy) list +select1 :: [a] -> [(a,[a])] +select1 [] = error "select1: empty list" +select1 zs = go zs where + go [x] = [(x,[])] + go (x:xs) = (x,xs) : map (\(y,ys) -> (y,x:ys)) (go xs) + -------------------------------------------------------------------------------- listToArray :: [a] -> Array Int a