From 2a3f5e895ad3ef8e579ccb70e05e82a6ffcdb3a2 Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Fri, 24 Jan 2025 11:14:19 +0100 Subject: [PATCH] refactor constant column handling --- src/Gate/Computation.hs | 11 ++++--- src/Gate/Selector.hs | 72 +++++++++++++++++++++++++++++++++++++---- src/Gate/Vars.hs | 18 ++++++----- src/Plonk/Vanishing.hs | 24 +++++++------- src/Types.hs | 5 ++- src/testmain.hs | 2 +- 6 files changed, 100 insertions(+), 32 deletions(-) diff --git a/src/Gate/Computation.hs b/src/Gate/Computation.hs index f44ad6c..8481ff4 100644 --- a/src/Gate/Computation.hs +++ b/src/Gate/Computation.hs @@ -176,6 +176,7 @@ type Constraint = Expr_ -- Typically this will be the evaluations of the column polynomials at @zeta@ data EvaluationVars a = MkEvaluationVars { local_selectors :: Array Int a -- ^ the selectors + , local_lkp_sels :: Array Int a -- ^ the lookup selectors , local_constants :: Array Int a -- ^ the circuit constants , local_wires :: Array Int a -- ^ the advice wires (witness) , public_inputs_hash :: [F] -- ^ only used in @PublicInputGate@ @@ -186,6 +187,7 @@ data EvaluationVars a = MkEvaluationVars testEvaluationVarsBase :: EvaluationVars F testEvaluationVarsBase = MkEvaluationVars { local_selectors = listArray (0, 0) [] + , local_lkp_sels = listArray (0, 0) [] , local_constants = listArray (0, 1) [666,77] , local_wires = listArray (0,134) [ 1001 + 71 * fromInteger i | i<-[0..134] ] , public_inputs_hash = [101,102,103,104] @@ -202,10 +204,11 @@ evalConstraint scope (MkEvaluationVars{..}) expr = evalExprWith f expr where Just y -> y Nothing -> error $ "variable _" ++ n ++ show i ++ " not in scope" ProofVar v -> case v of - SelV k -> local_selectors ! k - ConstV k -> local_constants ! k - WireV k -> local_wires ! k - PIV k -> fromBase (public_inputs_hash !! k) + SelV k -> local_selectors ! k + LkpSelV k -> local_lkp_sels ! k + ConstV k -> local_constants ! k + WireV k -> local_wires ! k + PIV k -> fromBase (public_inputs_hash !! k) evalConstraints :: Scope FExt -> EvaluationVars FExt -> [Constraint] -> [FExt] evalConstraints scope vars = map (evalConstraint scope vars) diff --git a/src/Gate/Selector.hs b/src/Gate/Selector.hs index d07b4b2..a44859f 100644 --- a/src/Gate/Selector.hs +++ b/src/Gate/Selector.hs @@ -20,13 +20,68 @@ import Misc.Aux -------------------------------------------------------------------------------- +data SelectorConfig = MkSelectorConfig + { numGateSelectors :: Int -- ^ number of gate selectors (usually 2-3) + , numLookupSelectors :: Int -- ^ number of lookup selectors (`4 + #nluts`) + , numGateConstants :: Int -- ^ number of gate constants (normally 2) + , numSigmaColumns :: Int -- ^ number of sigma columns (normally 80) + } + deriving Show + +getSelectorConfig :: CommonCircuitData -> SelectorConfig +getSelectorConfig (MkCommonCircuitData{..}) + | circuit_num_lookup_selectors /= expected_lookup_sels + = error "getSelectorConfig: fatal: num_lookup_selectors /= (4 + #nluts)" + | circuit_num_constants /= num_gate_selectors + circuit_num_lookup_selectors + config_num_constants + = error "getSelectorConfig: fatal: constant columns tally does not add up!" + | otherwise = MkSelectorConfig + { numGateSelectors = num_gate_selectors + , numLookupSelectors = circuit_num_lookup_selectors + , numGateConstants = config_num_constants + , numSigmaColumns = config_num_routed_wires + } + where + MkCircuitConfig{..} = circuit_config + nluts = length circuit_luts + expected_lookup_sels = if nluts == 0 then 0 else (4 + nluts) + num_gate_selectors = length (selector_groups circuit_selectors_info) + + -- NOTE: + -- circuit_num_constants = total number of constant columns (selectors + lookup_selectors + constants) + -- config_num_constants = only the gate constants + +-------------------------------------------------------------------------------- + +data ConstantColumns a = MkConstantColumns + { gateSelectors :: [a] + , lookupSelectors :: [a] + , gateConstants :: [a] + } + deriving Show + +splitConstantColumns :: SelectorConfig -> [a] -> ConstantColumns a +splitConstantColumns (MkSelectorConfig{..}) xs + | not (null rest3) = error "splitConstantColumns: fatal: numbers do not add up" + | length konst /= numGateConstants = error "splitConstantColumns: fatal: not enough constant columns" + | otherwise = MkConstantColumns + { gateSelectors = gate_sel + , lookupSelectors = lkp_sel + , gateConstants = konst + } + where + (gate_sel,rest1) = splitAt numGateSelectors xs + (lkp_sel ,rest2) = splitAt numLookupSelectors rest1 + (konst ,rest3) = splitAt numGateConstants rest2 + +-------------------------------------------------------------------------------- + -- | Given an evaluation point @x@ and a gate index @k@, we compute -- the evaluation of the corresponding selector polynomial -- -- Note: In the actual protocol, we have @x = S_g(zeta)@ -- -evalSelectorPoly :: SelectorsInfo -> FExt -> Int -> FExt -evalSelectorPoly (MkSelectorsInfo{..}) x k = value where +evalGateSelectorPoly :: SelectorsInfo -> FExt -> Int -> FExt +evalGateSelectorPoly (MkSelectorsInfo{..}) x k = value where group_idx = selector_indices !! k range = selector_groups !! group_idx initial = if length selector_groups > 1 then unused - x else 1 @@ -35,11 +90,14 @@ evalSelectorPoly (MkSelectorsInfo{..}) x k = value where -- | Given the evaluations of the selector column polynomials, we evaluate -- all the gate selectors -evalSelectors :: SelectorsInfo -> [FExt] -> [FExt] -evalSelectors selInfo@(MkSelectorsInfo{..}) xs = values where - values = [ evalSelectorPoly selInfo (xs!!grp) i | (i,grp) <- zip [0..] selector_indices ] +evalGateSelectors :: SelectorsInfo -> [FExt] -> [FExt] +evalGateSelectors selInfo@(MkSelectorsInfo{..}) xs = values where + values = [ evalGateSelectorPoly selInfo (xs!!grp) i | (i,grp) <- zip [0..] selector_indices ] -numSelectorColumns :: SelectorsInfo -> Int -numSelectorColumns selInfo = length (selector_groups selInfo) +{- +-- | Number of /gate selector/ column (does not include the lookup selectors!) +numGateSelectorColumns :: SelectorsInfo -> Int +numGateSelectorColumns selInfo = length (selector_groups selInfo) +-} -------------------------------------------------------------------------------- diff --git a/src/Gate/Vars.hs b/src/Gate/Vars.hs index b0f96c4..1a51423 100644 --- a/src/Gate/Vars.hs +++ b/src/Gate/Vars.hs @@ -16,18 +16,20 @@ import Misc.Pretty -- | These index into a row + public input data PlonkyVar - = SelV Int -- ^ selector variable - | ConstV Int -- ^ constant variable - | WireV Int -- ^ wire variable - | PIV Int -- ^ public input hash variable (technically these are constants, not variables) + = SelV Int -- ^ selector variable + | LkpSelV Int -- ^ lookup selector variable + | ConstV Int -- ^ constant variable + | WireV Int -- ^ wire variable + | PIV Int -- ^ public input hash variable (technically these are constants, not variables) deriving (Eq,Ord,Show) instance Pretty PlonkyVar where prettyPrec _ v = case v of - SelV k -> showString ("s" ++ show k) - ConstV k -> showString ("c" ++ show k) - WireV k -> showString ("w" ++ show k) - PIV k -> showString ("h" ++ show k) + SelV k -> showString ("s" ++ show k) + LkpSelV k -> showString ("l" ++ show k) + ConstV k -> showString ("c" ++ show k) + WireV k -> showString ("w" ++ show k) + PIV k -> showString ("h" ++ show k) -------------------------------------------------------------------------------- -- * Variables diff --git a/src/Plonk/Vanishing.hs b/src/Plonk/Vanishing.hs index b0db24e..6d05219 100644 --- a/src/Plonk/Vanishing.hs +++ b/src/Plonk/Vanishing.hs @@ -64,7 +64,7 @@ combineWithPowersOfAlpha alpha xs = foldl' f 0 (reverse xs) where evalAllPlonkConstraints :: CommonCircuitData -> ProofWithPublicInputs -> ProofChallenges -> [FExt] evalAllPlonkConstraints - (MkCommonCircuitData{..}) + common_data@(MkCommonCircuitData{..}) (MkProofWithPublicInputs{..}) (MkProofChallenges{..}) = finals where @@ -81,17 +81,17 @@ evalAllPlonkConstraints MkProof{..} = the_proof MkOpeningSet{..} = openings - nselectors = numSelectorColumns circuit_selectors_info - opening_selectors = take nselectors opening_constants + MkSelectorConfig{..} = getSelectorConfig common_data + opening_gate_selectors = take numGateSelectors opening_constants nn = fri_nrows circuit_fri_params maxdeg = circuit_quotient_degree_factor pi_hash = sponge public_inputs - + -- gate constraints - eval_vars = toEvaluationVars pi_hash circuit_selectors_info openings + eval_vars = toEvaluationVars common_data pi_hash openings gate_prgs = map gateProgram circuit_gates - sel_values = evalSelectors circuit_selectors_info opening_selectors + sel_values = evalGateSelectors circuit_selectors_info opening_gate_selectors unfiltered = map (runStraightLine eval_vars) gate_prgs filtered = zipWith (\s cons -> map (*s) cons) sel_values unfiltered gates = combineFilteredGateConstraints filtered @@ -129,16 +129,18 @@ combineFilteredGateConstraints = foldl1 (longZipWith 0 0 (+)) -------------------------------------------------------------------------------- -toEvaluationVars :: Digest -> SelectorsInfo -> OpeningSet -> EvaluationVars FExt -toEvaluationVars pi_hash selinfo (MkOpeningSet{..}) = +toEvaluationVars :: CommonCircuitData -> Digest -> OpeningSet -> EvaluationVars FExt +toEvaluationVars common_data pi_hash (MkOpeningSet{..}) = MkEvaluationVars - { local_selectors = listToArray (take nsels opening_constants) - , local_constants = listToArray (drop nsels opening_constants) + { local_selectors = listToArray gateSelectors + , local_lkp_sels = listToArray lookupSelectors + , local_constants = listToArray gateConstants , local_wires = listToArray opening_wires , public_inputs_hash = digestToList pi_hash } where - nsels = numSelectorColumns selinfo + selcfg = getSelectorConfig common_data + MkConstantColumns{..} = splitConstantColumns selcfg opening_constants -------------------------------------------------------------------------------- diff --git a/src/Types.hs b/src/Types.hs index acc8789..a95405d 100644 --- a/src/Types.hs +++ b/src/Types.hs @@ -45,7 +45,7 @@ data CommonCircuitData = MkCommonCircuitData , circuit_selectors_info :: SelectorsInfo -- ^ Information on the circuit's selector polynomials. , circuit_quotient_degree_factor :: Int -- ^ The degree of the PLONK quotient polynomial. , circuit_num_gate_constraints :: Int -- ^ The largest number of constraints imposed by any gate. - , circuit_num_constants :: Int -- ^ The number of constant wires. + , circuit_num_constants :: Int -- ^ The number of constant columns wires. , circuit_num_public_inputs :: Int -- ^ Number of public inputs , circuit_k_is :: [F] -- ^ The @{k_i}@ values (coset shifts) used in @S_I D_i@ in Plonk's permutation argument. , circuit_num_partial_products :: Int -- ^ The number of partial products needed to compute the `Z` polynomials; @ = ceil( #routed / max_degree ) - 1@ @@ -58,6 +58,9 @@ data CommonCircuitData = MkCommonCircuitData circuit_nrows :: CommonCircuitData -> Int circuit_nrows = fri_nrows . circuit_fri_params +circuit_num_luts :: CommonCircuitData -> Int +circuit_num_luts = length . circuit_luts + instance FromJSON CommonCircuitData where parseJSON = genericParseJSON defaultOptions { fieldLabelModifier = drop 8 } --instance ToJSON CommonCircuitData where toJSON = genericToJSON defaultOptions { fieldLabelModifier = drop 8 } diff --git a/src/testmain.hs b/src/testmain.hs index 95c03ac..6ec4fd2 100644 --- a/src/testmain.hs +++ b/src/testmain.hs @@ -48,7 +48,7 @@ main = do let challenges = proofChallenges common_data verifier_data proof_data - print challenges + -- print challenges print $ evalCombinedPlonkConstraints common_data proof_data challenges print $ checkCombinedPlonkEquations' common_data proof_data challenges