From 49b5a0035e1995da2765e6226e4f26db438b8f45 Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Sat, 8 Feb 2025 17:51:21 +0100 Subject: [PATCH] FRI verification seems to work (finally...) --- README.md | 4 + src/Algebra/FFT.hs | 95 ++++++++ src/Algebra/Goldilocks.hs | 10 + src/Algebra/GoldilocksExt.hs | 8 + src/Challenge/FRI.hs | 12 +- src/Hash/Merkle.hs | 45 ++++ src/Hash/Sponge.hs | 5 - src/Misc/Aux.hs | 61 +++++- src/Plonk/FRI.hs | 410 +++++++++++++++++++++++++++++++++++ src/Plonk/Verifier.hs | 13 ++ src/Types.hs | 37 +++- src/testmain.hs | 16 +- 12 files changed, 684 insertions(+), 32 deletions(-) create mode 100644 src/Algebra/FFT.hs create mode 100644 src/Hash/Merkle.hs create mode 100644 src/Plonk/FRI.hs diff --git a/README.md b/README.md index f704cce..a262df9 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,9 @@ Plonky2 verifier circuits for other proof systems, etc) Note: It's deliberately not a goal for this verifier to be efficient; instead we try to focus on simplicity. +This code can be a bit ugly at some places. There is a reason for that, namely, +that the semantics we have to emulate, are _somewhat idiosyncratic_... + ### Implementation status @@ -28,6 +31,7 @@ try to focus on simplicity. - [x] Constraints check - [ ] FRI check - [x] Support lookup tables +- [ ] Support zero-knowledge - [x] Documenting Plonky2 internals and the verifier algorithm (WIP) - [ ] Cabalize diff --git a/src/Algebra/FFT.hs b/src/Algebra/FFT.hs new file mode 100644 index 0000000..7ad5ca5 --- /dev/null +++ b/src/Algebra/FFT.hs @@ -0,0 +1,95 @@ + +{-# LANGUAGE BangPatterns, StrictData #-} +module Algebra.FFT where + +-------------------------------------------------------------------------------- + +import Data.Array +import Data.List +import Data.Bits +import Data.Word + +import Algebra.Goldilocks +import Algebra.GoldilocksExt + +import Misc.Aux + +-------------------------------------------------------------------------------- + +-- | Reverse the order of bits in an n-bit word +reverseBits :: Log2 -> Word64 -> Word64 +reverseBits (Log2 n) w = foldl' (.|.) 0 + [ shiftL ((shiftR w k) .&. 1) (n-k-1) | k<-[0..n-1] ] + +reverseBitsInt :: Log2 -> Int -> Int +reverseBitsInt log2 = fromIntegral . reverseBits log2 . fromIntegral + +reverseIndexBitsNaive :: Array Int a -> Array Int a +reverseIndexBitsNaive arr1 = arr2 where + (0,n1) = bounds arr1 + log2@(Log2 k) = safeLog2 (n1 + 1) + arr2 = array (0,n1) [ (reverseBitsInt log2 i , x) | (i,x) <- assocs arr1 ] + +reverseIndexBits :: Array Int a -> Array Int a +reverseIndexBits = reverseIndexBitsNaive + +reverseIndexBitsList :: [a] -> [a] +reverseIndexBitsList = elems . reverseIndexBitsNaive . listToArray + +-------------------------------------------------------------------------------- + +powersOf' :: Num a => a -> a -> [a] +powersOf' !start !g = go start where go !x = x : go (g*x) + +powersOf :: Num a => a -> [a] +powersOf = powersOf' 1 + +-------------------------------------------------------------------------------- + +{- +ifft :: [FExt] -> [FExt] +ifft xs = go kk xs where + nn = length xs + kk = safeLog2 nn + g = subgroupGenerator kk + ginv = recip g + hs = powersOf' half ginv + half = 1 / 2 :: F + + go (Log2 0) [x] = [ x ] + go (Log2 1) [x,y] = [ (x+y)/2 , (x-y)/2 ] + go (Log2 k) input = + case splitAt halfN input of + (xs,ys) -> go km1 (zipWith f1 xs ys ) ++ + go km1 (zipWith3 f2 hs xs ys) + where + f1 !x !y = scaleExt half (x + y) + f2 !h !x !y = scaleExt h (x - y) + halfN = exp2 km1 + km1 = Log2 k - 1 +-} + +-------------------------------------------------------------------------------- + +naiveFFT :: [FExt] -> [FExt] +naiveFFT xs = ys where + nn = length xs + kk = safeLog2 nn + g = subgroupGenerator kk + ys = [ sum [ scaleExt (g^(j*k)) (xs!!j) | j<-[0..nn-1] ] + | k <- [0..nn-1] + ] + +naiveIFFT :: [FExt] -> [FExt] +naiveIFFT xs = ys where + nn = length xs + kk = safeLog2 nn + g = subgroupGenerator kk + ginv = recip g + fn = toF (fromIntegral nn) + ys = map (scaleExt (recip fn)) + [ sum [ scaleExt (ginv^(j*k)) (xs!!j) | j<-[0..nn-1] ] + | k <- [0..nn-1] + ] + +-------------------------------------------------------------------------------- diff --git a/src/Algebra/Goldilocks.hs b/src/Algebra/Goldilocks.hs index 701f243..5568bba 100644 --- a/src/Algebra/Goldilocks.hs +++ b/src/Algebra/Goldilocks.hs @@ -132,6 +132,9 @@ modp a = mod a goldilocksPrime mkGoldilocks :: Integer -> Goldilocks mkGoldilocks = Goldilocks . modp +mulGen :: F +mulGen = mkGoldilocks 0xc65c18b67785d900 + -------------------------------------------------------------------------------- neg :: Goldilocks -> Goldilocks @@ -173,3 +176,10 @@ pow x e -------------------------------------------------------------------------------- +-- | @sum alpha^i * x_i@ +reduceWithPowers :: Num a => a -> [a] -> a +reduceWithPowers alpha xs = go xs where + go [] = 0 + go (!x:xs) = x + alpha * go xs + +-------------------------------------------------------------------------------- diff --git a/src/Algebra/GoldilocksExt.hs b/src/Algebra/GoldilocksExt.hs index 4b32645..0d1ef98 100644 --- a/src/Algebra/GoldilocksExt.hs +++ b/src/Algebra/GoldilocksExt.hs @@ -26,6 +26,7 @@ type FExt = GoldilocksExt -------------------------------------------------------------------------------- +-- | The ring @R[X] / (X^2-7)@ data Ext a = MkExt !a !a deriving Eq @@ -99,6 +100,13 @@ powExt x e -------------------------------------------------------------------------------- +flattenExt :: [Ext a] -> [a] +flattenExt = go where + go ((MkExt x y) : rest) = x : y : go rest + go [] = [] + +-------------------------------------------------------------------------------- + rndExt :: IO FExt rndExt = do x <- rndF diff --git a/src/Challenge/FRI.hs b/src/Challenge/FRI.hs index 08f2ae0..a9f9d54 100644 --- a/src/Challenge/FRI.hs +++ b/src/Challenge/FRI.hs @@ -14,11 +14,9 @@ import Algebra.Goldilocks import Algebra.GoldilocksExt import Hash.Sponge import Hash.Digest -import Types import Challenge.Monad - --- import Debug.Trace --- debug x y z = trace ("\n - " ++ x ++ " -> " ++ show y) z +import Types +import Misc.Aux -------------------------------------------------------------------------------- @@ -26,7 +24,7 @@ import Challenge.Monad data FriChallenges = MkFriChallenges { fri_alpha :: FExt -- ^ Scaling factor to combine polynomials. , fri_betas :: [FExt] -- ^ Betas used in the FRI commit phase reductions. - , fri_pow_response :: F -- ^ proof-of-work "response" + , fri_pow_response :: F -- ^ proof-of-work \"response\" , fri_query_indices :: [Int] -- ^ Indices at which the oracle is queried in FRI. } deriving (Eq,Show) @@ -44,7 +42,7 @@ newtype FriOpeningBatch absorbFriOpenings :: FriOpenings -> Duplex () absorbFriOpenings (MkFriOpenings batches) = mapM_ (absorb . values) batches --- | Just reordering and concatenating things... +-- | Just /reordering/ and concatenating things... toFriOpenings :: OpeningSet -> FriOpenings toFriOpenings (MkOpeningSet{..}) = MkFriOpenings [ batch_this, batch_next ] where @@ -92,7 +90,7 @@ friChallenges common_data verifier_data proof = do pow_response <- squeeze -- query indices - let lde_size = shiftL 1 (degree_bits + fri_rate_bits fri_config) + let lde_size = exp2' (degree_bits + fri_rate_bits fri_config) let num_fri_queries = fri_num_query_rounds fri_config let f :: F -> Int f felt = fromInteger (mod (asInteger felt) lde_size) diff --git a/src/Hash/Merkle.hs b/src/Hash/Merkle.hs new file mode 100644 index 0000000..ab9741f --- /dev/null +++ b/src/Hash/Merkle.hs @@ -0,0 +1,45 @@ + +{-# LANGUAGE BangPatterns, StrictData, RecordWildCards #-} +module Hash.Merkle where + +-------------------------------------------------------------------------------- + +import Data.Bits + +import Algebra.Goldilocks + +import Hash.Digest +import Hash.Poseidon +import Hash.Sponge + +import Types +import Misc.Aux + +-------------------------------------------------------------------------------- + +-- | Compression function for Merkle trees +compress :: Digest -> Digest -> Digest +compress x y = extractDigest $ permutation $ listToState s0 where + s0 = digestToList x ++ digestToList y ++ [0,0,0,0] + +-------------------------------------------------------------------------------- + +reconstructMerkleRoot :: [F] -> (Int,MerkleProof) -> (Int,Digest) +reconstructMerkleRoot leaf = reconstructMerkleRoot' (sponge leaf) + +reconstructMerkleRoot' :: Digest -> (Int,MerkleProof) -> (Int,Digest) +reconstructMerkleRoot' leaf_digest (leaf_idx, MkMerkleProof{..}) = go leaf_idx leaf_digest siblings where + go !idx !leaf [] = (idx, leaf) + go !idx !leaf (this:rest) = go idx' leaf' rest where + idx' = shiftR idx 1 + leaf' = if isEven idx + then compress leaf this + else compress this leaf + +checkMerkleProof :: MerkleCap -> Int -> [F] -> MerkleProof -> Bool +checkMerkleProof cap idx leaf proof = (cap_roots!!rootidx == root) where + MkMerkleCap cap_roots = cap + (rootidx, root) = reconstructMerkleRoot leaf (idx,proof) + +-------------------------------------------------------------------------------- + diff --git a/src/Hash/Sponge.hs b/src/Hash/Sponge.hs index abc44a7..26afd0e 100644 --- a/src/Hash/Sponge.hs +++ b/src/Hash/Sponge.hs @@ -41,10 +41,5 @@ spongeWithPad what = go zeroState (what ++ [1]) where then listToState $ xs ++ replicate (8-k-1) 0 ++ [1] ++ drop 8 (elems arr) else listToState $ xs ++ drop k (elems arr) --- | Compression function for Merkle trees -compress :: Digest -> Digest -> Digest -compress x y = extractDigest $ permutation $ listToState s0 where - s0 = digestToList x ++ digestToList y ++ [0,0,0,0] - -------------------------------------------------------------------------------- diff --git a/src/Misc/Aux.hs b/src/Misc/Aux.hs index 393af31..2f835b6 100644 --- a/src/Misc/Aux.hs +++ b/src/Misc/Aux.hs @@ -14,18 +14,46 @@ import Data.Aeson hiding ( Array , pairs ) import GHC.Generics -------------------------------------------------------------------------------- +-- * Log2 +-- | The base 2 logarithm of an integer newtype Log2 = Log2 Int deriving newtype (Eq,Ord,Show,Num) +deriving instance Generic Log2 + +instance ToJSON Log2 where toJSON (Log2 x) = toJSON x +instance FromJSON Log2 where parseJSON y = Log2 <$> parseJSON y + fromLog2 :: Log2 -> Int fromLog2 (Log2 k) = k exp2 :: Log2 -> Int exp2 (Log2 k) = shiftL 1 k +exp2' :: Log2 -> Integer +exp2' (Log2 k) = shiftL 1 k + +safeLog2 :: Int -> Log2 +safeLog2 n = + if exp2 k == n + then k + else error "safeLog2: input is not a power of two" + where + k = floorLog2 n + +floorLog2 :: Int -> Log2 +floorLog2 = floorLog2' . fromIntegral + +floorLog2' :: Integer -> Log2 +floorLog2' = go where + go 0 = -1 + go 1 = 0 + go !x = 1 + go (shiftR x 1) + -------------------------------------------------------------------------------- +-- * Integers divCeil :: Int -> Int -> Int divCeil n k = div (n+k-1) k @@ -33,15 +61,16 @@ divCeil n k = div (n+k-1) k divFloor :: Int -> Int -> Int divFloor = div --------------------------------------------------------------------------------- +---------------------------------------- -range :: Int -> [Int] -range k = [0..k-1] +isEven :: Int -> Bool +isEven n = (n .&. 1) == 0 -range' :: Int -> Int -> [Int] -range' a b = [a..b-1] +isOdd :: Int -> Bool +isOdd n = (n .&. 1) /= 0 -------------------------------------------------------------------------------- +-- * Lists -- | Consecutive pairs of a list pairs :: [a] -> [(a,a)] @@ -58,6 +87,18 @@ safeZipWith f = go where go [] [] = [] go _ _ = error "safeZipWith: different input lengths" +safeZipWith3 :: (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d] +safeZipWith3 f = go where + go (x:xs) (y:ys) (z:zs) = f x y z : go xs ys zs + go [] [] [] = [] + go _ _ _ = error "safeZipWith3: different input lengths" + +safeZipWith4 :: (a -> b -> c -> d -> e) -> [a] -> [b] -> [c] -> [d] -> [e] +safeZipWith4 f = go where + go (x:xs) (y:ys) (z:zs) (w:ws) = f x y z w : go xs ys zs ws + go [] [] [] [] = [] + go _ _ _ _ = error "safeZipWith4: different input lengths" + longZipWith :: a -> b -> (a -> b -> c) -> [a] -> [b] -> [c] longZipWith x0 y0 f = go where go [] [] = [] @@ -82,6 +123,7 @@ remove1 :: [a] -> [[a]] remove1 = map snd . select1 -------------------------------------------------------------------------------- +-- * Arrays listToArray :: [a] -> Array Int a listToArray xs = listArray (0, length xs - 1) xs @@ -90,6 +132,15 @@ arrayLength :: Array Int a -> Int arrayLength arr = let (a,b) = bounds arr in b-a+1 -------------------------------------------------------------------------------- +-- * ranges + +range :: Int -> [Int] +range k = [0..k-1] + +range' :: Int -> Int -> [Int] +range' a b = [a..b-1] + +---------------------------------------- -- | The interval @[a,b)@ (inclusive on the left, exclusive on the right) data Range = MkRange diff --git a/src/Plonk/FRI.hs b/src/Plonk/FRI.hs new file mode 100644 index 0000000..79c2cf2 --- /dev/null +++ b/src/Plonk/FRI.hs @@ -0,0 +1,410 @@ + +-- | Verify the FRI protocol + +{-# LANGUAGE BangPatterns, StrictData, RecordWildCards, DeriveFunctor, DeriveFoldable #-} +module Plonk.FRI where + +-------------------------------------------------------------------------------- + +import Data.Array +import Data.Bits +import Data.Word +import Data.List +import Data.Foldable + +import Algebra.Goldilocks +import Algebra.GoldilocksExt +import Algebra.FFT + +import Challenge.FRI +import Challenge.Verifier + +import Hash.Digest +import Hash.Sponge +import Hash.Merkle + +import Types +import Misc.Aux + +{- +-- debugging only +import Text.Printf +import Challenge.Verifier +import Debug.Trace +debug !msg !x y = trace (">>> " ++ msg ++ ": " ++ show x) y +-} + +-------------------------------------------------------------------------------- + +data Oracles a = MkOracles + { oracle_constants :: a + , oracle_witness :: a + , oracle_pp_lookup :: a + , oracle_quotient :: a + } + deriving (Show,Functor,Foldable) + +enumerateOracles :: Oracles a -> [a] +enumerateOracles (MkOracles{..}) = + [ oracle_constants + , oracle_witness + , oracle_pp_lookup + , oracle_quotient + ] + +-- | Size of the 4 oracle matrices +oracleWidths :: CommonCircuitData -> Oracles Int +oracleWidths (MkCommonCircuitData{..}) = widths where + MkCircuitConfig{..} = circuit_config + widths = MkOracles + { oracle_constants = circuit_num_constants + config_num_routed_wires + , oracle_witness = config_num_wires + , oracle_pp_lookup = r * (1 + circuit_num_partial_products + circuit_num_lookup_polys) + , oracle_quotient = r * circuit_quotient_degree_factor + } + r = config_num_challenges + +buildListOracle :: Oracles Int -> [[a]] -> Oracles [a] +buildListOracle (MkOracles lc lw lp lq) [c,w,p,q] + = if (length c == lc) && + (length w == lw) && + (length p == lp) && + (length q == lq) + then MkOracles c w p q + else error "buildListOracle: list size do not match the expected" +buildListOracle _ _ = error "buildListOracle: expecting a list of 4 lists" + +---------------------------------------- + +validateMerkleCapLength :: Log2 -> MerkleCap -> MerkleCap +validateMerkleCapLength height cap@(MkMerkleCap roots) + | ok = cap + | otherwise = error "validateMerkleCapLength: cap has wrong size" + where + ok = length roots == len + len = exp2 height + +toMerkleOracles :: VerifierCircuitData -> Proof -> Oracles MerkleCap +toMerkleOracles (MkVerifierCircuitData{..}) (MkProof{..}) = oracles where + MkCommonCircuitData{..} = verifier_common + MkFriParams{..} = circuit_fri_params + validate = validateMerkleCapLength (fri_cap_height fri_config) + oracles = MkOracles + { oracle_constants = validate $ constants_sigmas_cap verifier_only + , oracle_witness = validate $ wires_cap + , oracle_pp_lookup = validate $ plonk_zs_partial_products_cap + , oracle_quotient = validate $ quotient_polys_cap + } + +-------------------------------------------------------------------------------- +-- * Initial tree proofs + +-- | Checks the initial tree proofs, and returns +-- evaluation oracles at @x = g * (eta ^ query_idx)@ +-- (it's just some rearrangement...) +checkInitialTreeProofs :: CommonCircuitData -> Oracles MerkleCap -> FExt -> Int -> FriInitialTreeProof -> Oracles [F] +checkInitialTreeProofs common_data oracles alpha query_idx (MkFriInitialTreeProof{..}) + | length evals_proofs /= 4 = error "checkInitialTreeProofs: expecting 4 Merkle proofs for the 4 oracles" + | not merkle_are_ok = error "checkInitialTreeProofs: at least one Merkle proof failed" + | otherwise = result + where + merkle_are_ok = and + [ checkMerkleProof cap query_idx leaf proof + | (cap,(leaf,proof)) <- safeZip (enumerateOracles oracles) evals_proofs + ] + config = circuit_config common_data + widths = oracleWidths common_data + result = buildListOracle widths (map fst evals_proofs) + +-- | Combinations (with powers of alpha) of openings +data PrecomputedReducedOpenings = MkPrecomputedReducedOpenings + { sum_this_row :: FExt -- ^ sum over the openings of the full rows + , sum_next_row :: FExt -- ^ sum over the few openings we need from the \"next row\" + } + deriving Show + +---------------------------------------- + +precomputeReducedOpenings :: FExt -> FriOpenings -> PrecomputedReducedOpenings +precomputeReducedOpenings alpha (MkFriOpenings [one,two]) = result where + result = MkPrecomputedReducedOpenings this next + this = reduceWithPowers alpha row1 + next = reduceWithPowers alpha row2 + MkFriOpeningBatch row1 = one + MkFriOpeningBatch row2 = two + +-- | Calculates the evaluation of the \"combined polynomial\" at @x0 = g * eta^query_idx@ +-- +-- More precisely, this is +-- +-- > G0(X) - Y0 G1(X) - Y1 +-- > ------------ + alpha^M * ------------------ +-- > X - zeta X - omega*zeta +-- +-- where (Y0,Y1) are the \"precomputed reduced openings\", +-- G0(X) and G1(X) are the column polynomial "batches" combined by powers of @alpha@, +-- and M is the size of the first batch. Finally @X -> x0@ is substituted. +-- +-- The first batch contains all columns, the second only +-- "zs" and "lookup_zs". +-- +combineInitial :: CommonCircuitData -> ProofChallenges -> PrecomputedReducedOpenings -> Oracles [F] -> Int -> FExt +combineInitial (MkCommonCircuitData{..}) (MkProofChallenges{..}) preComp oracles@(MkOracles{..}) query_idx + | sanityCheck = result + | otherwise = error "combineInitial: sanity check failed" + where + + MkCircuitConfig{..} = circuit_config + MkFriChallenges{..} = fri_challenges + + MkPrecomputedReducedOpenings y0 y1 = preComp + + zeta = plonk_zeta + alpha = fri_alpha + + r = config_num_challenges + npp = divCeil config_num_routed_wires circuit_quotient_degree_factor + + sanityCheck = r * (npp + circuit_num_lookup_polys) == length oracle_pp_lookup + + (oracle_pp,oracle_lookup) = splitAt (r*npp) oracle_pp_lookup + + -- NOTE: this is /reordered/ the same way as FriOpenings, + -- except that we don't have the same Openings input structure + -- here to reuse... + -- + -- the whole Plonky2 codebase is seriously full of WTF-ness + firstBatch + = oracle_constants + ++ oracle_witness + ++ oracle_pp + ++ oracle_quotient + ++ oracle_lookup + secondBatch + = take r oracle_pp + ++ oracle_lookup + + len_1st_batch = length firstBatch + len_2nd_batch = length secondBatch + + g0 = reduceWithPowers alpha (map fromBase firstBatch ) + g1 = reduceWithPowers alpha (map fromBase secondBatch) + + logn_small = fri_degree_bits circuit_fri_params + logn_lde = fri_LDE_bits circuit_fri_params + omega = subgroupGenerator logn_small + eta = subgroupGenerator logn_lde + + rev_idx = reverseBitsInt logn_lde query_idx + point_x = fromBase (mulGen * pow_ eta rev_idx) + + loc0 = zeta + loc1 = fromBase omega * zeta + + one = (g0 - y0) / (point_x - loc0) + two = (g1 - y1) / (point_x - loc1) + + result = powExt_ alpha len_2nd_batch * one + two + +-------------------------------------------------------------------------------- +-- * Proof-of-work + +checkProofOfWork :: FriConfig -> FriChallenges -> Bool +checkProofOfWork (MkFriConfig{..}) (MkFriChallenges{..}) = ok where + lo_mask = fromInteger (exp2' fri_proof_of_work_bits - 1) :: Word64 + mask = shiftL lo_mask (64 - fromLog2 fri_proof_of_work_bits) + ok = (fromF fri_pow_response .&. mask) == 0 + +-------------------------------------------------------------------------------- +-- * Folding + +-- | Note: query indices index into the bit-reversed-order arrays!! +data QueryIndex = MkQueryIndex + { query_array_size :: Log2 + , query_index_rev :: Int + } + deriving (Eq,Ord,Show) + +queryLocation :: F -> QueryIndex -> F +queryLocation shift (MkQueryIndex arr_size idx_rev) = loc where + loc = shift * pow_ eta (reverseBitsInt arr_size idx_rev) + eta = subgroupGenerator arr_size + +foldQueryIdx :: Log2 -> QueryIndex -> QueryIndex +foldQueryIdx arityLog2@(Log2 arity_bits) (MkQueryIndex oldSize oldIdx) = MkQueryIndex newSize newIdx where + newSize = oldSize - arityLog2 + newIdx = shiftR oldIdx arity_bits + +-- | A coset of size @2^arity@, which is the unit we fold into a single field extension element +-- These are the leaves of the FRI commit phase Merkle trees +data Coset = MkCoset + { coset_size :: Log2 -- ^ logarithm of the size of the coset + , coset_offset :: F -- ^ the coset is shifted from the subgroup by this element + , coset_values :: [FExt] -- ^ values of a polynomial on this coset + } + deriving Show + +-- | Handling some of the fucked up conventions +prepareCoset :: F -> QueryIndex -> [FExt] -> Coset +prepareCoset shift (MkQueryIndex bigLog2 idx) values = coset where + smallLog2@(Log2 arity) = safeLog2 (length values) + ofs = shift * pow_ eta start + start = reverseBitsInt bigLog2 + $ shiftL (shiftR idx arity) arity + eta = subgroupGenerator bigLog2 + coset = MkCoset + { coset_size = smallLog2 + , coset_offset = ofs + , coset_values = reverseIndexBitsList values + } + +-- | \"Folds\" a coset with a given folding coefficient @beta@ +-- +foldCosetWith :: FExt -> Coset -> FExt -- (QueryIndex,Coset) -> (QueryIndex,FExt) +--foldCosetWith beta (oldQueryIdx,coset) = (newQueryIdx,final) where +foldCosetWith beta coset = final where + MkCoset arity_bits coset_x_loc xs = coset + -- MkQueryIdx oldSize revIdx = oldQueryIdx + -- newQueryIdx = foldQueryIdx arity_bits oldQueryIdx + arity = exp2 arity_bits + omega = subgroupGenerator arity_bits + invArity = (1 :: F) / fromIntegral arity + ys = [ sum + [ scaleExt (pow_ x_omega_j (-k)) (xs!!j) + | j <- [0..arity-1] + , let x_omega_j = coset_x_loc * pow_ omega j + ] + | k <- [0..arity-1] + ] + final = scaleExt invArity $ sum $ zipWith (*) (powersOf beta) ys + +data FoldingState = MkFoldingState + { folding_shift :: F + , folding_query_idx :: QueryIndex + , folding_upstream_eval :: FExt + } + deriving Show + +folding_query_loc :: FoldingState -> F +folding_query_loc (MkFoldingState shift (MkQueryIndex log2n idx) _eval) = loc where + loc = shift * pow_ eta (reverseBitsInt log2n idx) + eta = subgroupGenerator log2n + +data FoldingInput = MkFoldingInput + { folding_arity :: Log2 + , folding_beta :: FExt + , folding_merkle_cap :: MerkleCap + , folding_query_proof :: FriQueryStep + } + deriving Show + +foldAll :: FoldingState -> [FoldingInput] -> FoldingState +foldAll = foldl' foldingStep + +-- | \"Folds\" a coset with a given folding coefficient @beta@ +-- +foldingStep :: FoldingState -> FoldingInput -> FoldingState +foldingStep + (MkFoldingState oldShift oldQueryIdx oldEval) + (MkFoldingInput arityLog2 beta merkleCap (MkFriQueryStep evals proof)) + | not proofCheckOK = error "folding step Merkle proof does not check out" + | not evalCheckOK = error "folding step evaluation does not match the opening" + | not arityCheckOK = error "folding stpe: reduction strategy incompatibility" + | otherwise = MkFoldingState newShift newQueryIdx newEval + where + arityCheckOK = arityLog2 == safeLog2 (length evals) + proofCheckOK = checkMerkleProof merkleCap (query_index_rev newQueryIdx) (flattenExt evals) proof + evalCheckOK = evals !! (query_index_rev oldQueryIdx `mod` arity) == oldEval + newShift = pow_ oldShift arity + arity = exp2 arityLog2 + coset = prepareCoset oldShift oldQueryIdx evals + -- (newQueryIdx,newEval) = foldCosetWith beta (oldQueryIdx,coset) + newQueryIdx = foldQueryIdx arityLog2 oldQueryIdx + newEval = foldCosetWith beta coset + +evalPolynomialAt :: Num coeff => PolynomialCoeffs coeff -> coeff -> coeff +evalPolynomialAt (MkPolynomialCoeffs coeffs) loc = value where + value = sum $ zipWith (*) coeffs (powersOf loc) + +-------------------------------------------------------------------------------- + +data FriStepParams = MkFriStepParams + { step_input_degree :: Log2 + , step_arity :: Log2 + } + deriving Show + +expandReductionStrategy :: Log2 -> FriReductionStrategy -> [FriStepParams] +expandReductionStrategy degree_logn strategy = + case strategy of + ConstantArityBits arity_bits final_poly_bits -> expandConstantStrategy arity_bits final_poly_bits degree_logn + Fixed arities -> expandFixedStrategy arities degree_logn + _ -> error "reduction strategy not implemented" + where + + expandConstantStrategy :: Log2 -> Log2 -> Log2 -> [FriStepParams] + expandConstantStrategy arity_bits final_poly_bits = go where + go logn = if logn <= final_poly_bits + then [] + else (MkFriStepParams logn arity_bits) : go (logn - arity_bits) + + expandFixedStrategy :: [Log2] -> Log2 -> [FriStepParams] + expandFixedStrategy = go where + go [] _ = [] + go (a:as) logn = (MkFriStepParams logn a) : go as (logn - a) + +-------------------------------------------------------------------------------- + +checkFRIProof :: VerifierCircuitData -> Proof -> ProofChallenges -> Bool +checkFRIProof vkey@(MkVerifierCircuitData{..}) proof@(MkProof{..}) challenges = ok where + + MkProofChallenges{..} = challenges + MkCommonCircuitData{..} = common + MkFriChallenges{..} = fri_challenges + fri_proof@(MkFriProof{..}) = opening_proof + + common = verifier_common + fri_config = config_fri_config circuit_config + fri_openings = toFriOpenings openings + + ok = pow_ok && and oks + pow_ok = checkProofOfWork fri_config fri_challenges + oks = safeZipWith checkQueryRound fri_query_indices fri_query_round_proofs + + merkleOracles = toMerkleOracles vkey proof + precomp = precomputeReducedOpenings fri_alpha fri_openings + logn_lde = fri_LDE_bits circuit_fri_params + logn_degree = fri_degree_bits circuit_fri_params + steps_params = expandReductionStrategy logn_degree (fri_reduction_strategy fri_config) + + checkQueryRound :: Int -> FriQueryRound -> Bool + checkQueryRound idx this_round@(MkFriQueryRound{..}) = round_ok where + + row_evals = checkInitialTreeProofs common merkleOracles fri_alpha idx fri_initial_trees_proof + combined_eval = combineInitial common challenges precomp row_evals idx + + initialState = MkFoldingState + { folding_shift = mulGen + , folding_query_idx = MkQueryIndex logn_lde idx + , folding_upstream_eval = combined_eval + } + + foldingInputs = safeZipWith4 + (\step_params beta cap step -> MkFoldingInput (step_arity step_params) beta cap step) + steps_params + fri_betas + fri_commit_phase_merkle_caps + fri_steps + + -- note: foldingStep actually checks the steps, but for simplicity + -- we just throw an exception if the checks fail. + -- TODO: maybe better error handling (though I'm fed up with this shit) + finalState = foldl' foldingStep initialState foldingInputs + + x_final = folding_query_loc finalState + final_poly_eval = evalPolynomialAt fri_final_poly (fromBase x_final) + + round_ok = final_poly_eval == folding_upstream_eval finalState + +-------------------------------------------------------------------------------- + diff --git a/src/Plonk/Verifier.hs b/src/Plonk/Verifier.hs index 75cd095..3f55c52 100644 --- a/src/Plonk/Verifier.hs +++ b/src/Plonk/Verifier.hs @@ -18,6 +18,7 @@ import Algebra.GoldilocksExt import Challenge.Verifier import Plonk.Vanishing +import Plonk.FRI import Hash.Digest @@ -52,3 +53,15 @@ checkCombinedPlonkEquations' common proof_pis challenges = ok_list where -------------------------------------------------------------------------------- +verifyProof :: VerifierCircuitData -> ProofWithPublicInputs -> Bool +verifyProof vkey@(MkVerifierCircuitData{..}) pis@(MkProofWithPublicInputs{..}) = all_ok where + + common = verifier_common + challenges = proofChallenges common verifier_only pis + + all_ok = eqs_ok && fri_ok + + eqs_ok = checkCombinedPlonkEquations common pis challenges + fri_ok = checkFRIProof vkey the_proof challenges + +-------------------------------------------------------------------------------- diff --git a/src/Types.hs b/src/Types.hs index c84b517..7f353b1 100644 --- a/src/Types.hs +++ b/src/Types.hs @@ -6,6 +6,7 @@ module Types where import Data.Char import Data.Word +import Data.List import Data.Aeson import qualified Data.Aeson.KeyMap as KeyMap @@ -34,6 +35,8 @@ fromLookupTable (MkLookupTable pairs) = [ (toF inp, toF out) | (inp,out) <- pair instance ToJSON LookupTable where toJSON (MkLookupTable x) = toJSON x instance FromJSON LookupTable where parseJSON o = MkLookupTable <$> parseJSON o +---------------------------------------- + newtype PolynomialCoeffs coeff = MkPolynomialCoeffs { coeffs :: [coeff] } deriving (Eq,Show,Generic,ToJSON,FromJSON) @@ -72,7 +75,7 @@ data CircuitConfig = MkCircuitConfig , config_num_routed_wires :: Int -- ^ The number of routed wires, i.e. wires that will be involved in Plonk's permutation argument. , config_num_constants :: Int -- ^ The number of constants that can be used per gate. , config_use_base_arithmetic_gate :: Bool -- ^ Whether to use a dedicated gate for base field arithmetic, rather than using a single gate for both base field and extension field arithmetic. - , config_security_bits :: Int -- ^ Security level target + , config_security_bits :: Log2 -- ^ Security level target , config_num_challenges :: Int -- ^ The number of challenge points to generate, for IOPs that have soundness errors of (roughly) `degree / |F|`. , config_zero_knowledge :: Bool -- ^ Option to activate the zero-knowledge property. , config_randomize_unused_wires :: Bool -- ^ Option to disable randomization (useful for debugging). @@ -110,10 +113,10 @@ instance ToJSON SelectorsInfo where -------------------------------------------------------------------------------- -- * FRI types -data FriConfig = MkFrConfig - { fri_rate_bits :: Int -- ^ @rate = 2^{-rate_bits}@ - , fri_cap_height :: Int -- ^ Height of Merkle tree caps. - , fri_proof_of_work_bits :: Int -- ^ Number of bits used for grinding. +data FriConfig = MkFriConfig + { fri_rate_bits :: Log2 -- ^ @rate = 2^{-rate_bits}@ + , fri_cap_height :: Log2 -- ^ Height of Merkle tree caps. + , fri_proof_of_work_bits :: Log2 -- ^ Number of bits used for grinding. , fri_reduction_strategy :: FriReductionStrategy -- ^ The reduction strategy to be applied at each layer during the commit phase. , fri_num_query_rounds :: Int -- ^ Number of query rounds to perform. } @@ -123,9 +126,9 @@ instance FromJSON FriConfig where parseJSON = genericParseJSON defaultOptions { instance ToJSON FriConfig where toJSON = genericToJSON defaultOptions { fieldLabelModifier = drop 4 } data FriReductionStrategy - = Fixed { arity_bits_seq :: [Int] } - | ConstantArityBits { arity_bits :: Int , final_poly_bits :: Int } - | MinSize { opt_max_arity_bits :: Maybe Int } + = Fixed { arity_bits_seq :: [Log2] } + | ConstantArityBits { arity_bits :: Log2 , final_poly_bits :: Log2 } + | MinSize { opt_max_arity_bits :: Maybe Log2 } deriving (Eq,Show,Generic) instance FromJSON FriReductionStrategy where @@ -148,16 +151,25 @@ instance ToJSON FriReductionStrategy where data FriParams = MkFriParams { fri_config :: FriConfig -- ^ User-specified FRI configuration. , fri_hiding :: Bool -- ^ Whether to use a hiding variant of Merkle trees (where random salts are added to leaves). - , fri_degree_bits :: Int -- ^ The degree of the purported codeword, measured in bits. - , fri_reduction_arity_bits :: [Int] -- ^ The arity of each FRI reduction step, expressed as the log2 of the actual arity. + , fri_degree_bits :: Log2 -- ^ The degree of the purported codeword, measured in bits. + , fri_reduction_arity_bits :: [Log2] -- ^ The arity of each FRI reduction step, expressed as the log2 of the actual arity. } deriving (Eq,Show,Generic) -- | Number of rows in the circuit fri_nrows :: FriParams -> Int -fri_nrows params = 2^nbits where +fri_nrows params = exp2 nbits where nbits = fri_degree_bits params +-- | Logarithm of the size of the LDE codeword +fri_LDE_bits :: FriParams -> Log2 +fri_LDE_bits params = nbits where + nbits = fri_degree_bits params + fri_rate_bits (fri_config params) + +-- | Number of rows in the LDE codewords +fri_LDE_nrows :: FriParams -> Int +fri_LDE_nrows params = exp2 (fri_LDE_bits params) + instance FromJSON FriParams where parseJSON = genericParseJSON defaultOptions { fieldLabelModifier = drop 4 } instance ToJSON FriParams where toJSON = genericToJSON defaultOptions { fieldLabelModifier = drop 4 } @@ -215,6 +227,9 @@ newtype MerkleCap = MkMerkleCap [Digest] deriving (Eq,Show,Generic) +merkleCapSize :: MerkleCap -> Int +merkleCapSize (MkMerkleCap ds) = length ds + instance ToJSON MerkleCap where toJSON (MkMerkleCap caps) = toJSON caps instance FromJSON MerkleCap where parseJSON o = MkMerkleCap <$> parseJSON o diff --git a/src/testmain.hs b/src/testmain.hs index 6a70fc3..69fbe47 100644 --- a/src/testmain.hs +++ b/src/testmain.hs @@ -11,9 +11,11 @@ import Hash.Sponge import Hash.Digest import Algebra.Goldilocks import Challenge.Verifier -import Plonk.Vanishing import Plonk.Verifier +import Plonk.Vanishing +import Plonk.FRI + import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as L @@ -21,9 +23,10 @@ import qualified Data.ByteString.Lazy.Char8 as L main = do -- let prefix = "fibonacci" - -- let prefix = "recursion_outer" -- let prefix = "lookup" - let prefix = "multi_lookup" + -- let prefix = "multi_lookup" + -- let prefix = "recursion_middle" + let prefix = "recursion_outer" text_common <- L.readFile ("../json/" ++ prefix ++ "_common.json") text_vkey <- L.readFile ("../json/" ++ prefix ++ "_vkey.json" ) @@ -32,7 +35,8 @@ main = do let Just common_data = decode text_common :: Maybe CommonCircuitData let Just verifier_data = decode text_vkey :: Maybe VerifierOnlyCircuitData let Just proof_data = decode text_proof :: Maybe ProofWithPublicInputs - + let vkey = MkVerifierCircuitData verifier_data common_data + let pi_hash = sponge (public_inputs proof_data) putStrLn $ "public inputs hash = " ++ show pi_hash @@ -53,3 +57,7 @@ main = do print $ evalCombinedPlonkConstraints common_data proof_data challenges print $ checkCombinedPlonkEquations' common_data proof_data challenges + + -- debugFRI common_data verifier_data (the_proof proof_data) challenges + + putStrLn $ "proof verification result = " ++ show (verifyProof vkey proof_data)