From 4e88c0defbca18edb5884f5c86f4a01ea6d27fd3 Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Thu, 23 Jan 2025 13:14:55 +0100 Subject: [PATCH] fix the exponentiation gate and implement random access gate --- README.md | 2 +- src/Gate/Computation.hs | 37 ++++++++++++---- src/Gate/Constraints.hs | 38 ++++++++++++----- src/Gate/RandomAccess.hs | 91 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 150 insertions(+), 18 deletions(-) create mode 100644 src/Gate/RandomAccess.hs diff --git a/README.md b/README.md index cf0564b..4263107 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Supported gates: - [x] PublicInputGate - [x] PoseidonGate - [ ] PoseidonMdsGate -- [ ] RandomAccessGate +- [X] RandomAccessGate - [ ] ReducingGate - [ ] ReducingExtensionGate diff --git a/src/Gate/Computation.hs b/src/Gate/Computation.hs index 89f27aa..783697d 100644 --- a/src/Gate/Computation.hs +++ b/src/Gate/Computation.hs @@ -90,22 +90,29 @@ instance Pretty v => Pretty (LocalDef v) where -- | A straightline program encoding the computation of constraints data StraightLine = MkStraightLine - { localdefs :: [LocalDef Var_] -- ^ local definitions, in reverse order - , commits :: [Expr_] -- ^ committed constraints, in reverse order + { localdefs :: [LocalDef Var_] -- ^ local definitions (during compilation, in reverse order) + , commits :: [Expr_] -- ^ committed constraints (during compilation, in reverse order) , counter :: Int -- ^ fresh variable counter } deriving Show +reverseStraightLine :: StraightLine -> StraightLine +reverseStraightLine (MkStraightLine{..}) = MkStraightLine + { localdefs = reverse localdefs + , commits = reverse commits + , counter = counter + } + emptyStraightLine :: StraightLine emptyStraightLine = MkStraightLine [] [] 0 printStraightLine :: StraightLine -> IO () printStraightLine (MkStraightLine{..}) = do - forM_ (reverse localdefs) $ \def -> putStrLn (pretty def) - forM_ (reverse commits ) $ \expr -> putStrLn $ "constraint 0 == " ++ (pretty expr) + forM_ (localdefs) $ \def -> putStrLn (pretty def) + forM_ (commits ) $ \expr -> putStrLn $ "constraint 0 == " ++ (pretty expr) compileToStraightLine :: Compute () -> StraightLine -compileToStraightLine = fst . go emptyStraightLine where +compileToStraightLine = reverseStraightLine . fst . go emptyStraightLine where go :: StraightLine -> Compute a -> (StraightLine,a) go state instr = case instr of Return x -> (state,x) @@ -149,10 +156,13 @@ runStraightLine = runStraightLine' emptyScope runStraightLine' :: Scope FExt -> EvaluationVars FExt -> StraightLine -> [FExt] runStraightLine' iniScope vars (MkStraightLine{..}) = result where - finalScope = foldl' worker iniScope (reverse localdefs) - result = evalConstraints finalScope vars (reverse commits) + finalScope = foldl' worker iniScope (localdefs) + result = evalConstraints finalScope vars (commits) worker !scope (MkLocalDef i _ rhs) = IntMap.insert i (evalConstraint scope vars rhs) scope +runComputation :: EvaluationVars FExt -> Compute () -> [FExt] +runComputation evalvars action = runStraightLine evalvars (compileToStraightLine action) + -------------------------------------------------------------------------------- -- * Evaluation @@ -169,6 +179,19 @@ data EvaluationVars a = MkEvaluationVars } deriving (Show,Functor) +-- | used for testing the gate constraint +testEvaluationVarsBase :: EvaluationVars F +testEvaluationVarsBase = MkEvaluationVars + { local_selectors = listArray (0, 0) [] + , local_constants = listArray (0, 1) [666,77] + , local_wires = listArray (0,134) [ 1001 + 71 * fromInteger i | i<-[0..134] ] + , public_inputs_hash = [101,102,103,104] + } + +testEvaluationVarsExt :: EvaluationVars FExt +testEvaluationVarsExt = fmap f testEvaluationVarsBase where + f x = MkExt x 13 + evalConstraint :: Scope FExt -> EvaluationVars FExt -> Constraint -> FExt evalConstraint scope (MkEvaluationVars{..}) expr = evalExprWith f expr where f var = case var of diff --git a/src/Gate/Constraints.hs b/src/Gate/Constraints.hs index 18f6429..8373df1 100644 --- a/src/Gate/Constraints.hs +++ b/src/Gate/Constraints.hs @@ -22,6 +22,7 @@ import Gate.Base import Gate.Vars import Gate.Computation import Gate.Poseidon +import Gate.RandomAccess import Misc.Aux @@ -62,15 +63,7 @@ gateComputation gate = -- computes `out = base ^ (sum 2^i e_i)` -- order of witness variables: [ base, e[0],...,e[n-1], output, t[0]...t[n-1] ] - ExponentiationGate num_power_bits - -> let base = wire 0 - exp_bit i = wire (i+1) - out = wire (num_power_bits+1) - tmp_val 0 = 1 - tmp_val i = wire (num_power_bits+1+i) - cur_bit i = exp_bit (num_power_bits - 1 - i) - eq i = tmp_val (i-1) * (cur_bit i * base + 1 - cur_bit i) - tmp_val i - in commitList $ [ eq i | i <- range num_power_bits ] ++ [ out - tmp_val (num_power_bits-1) ] + ExponentiationGate num_power_bits -> exponentiationGateConstraints num_power_bits -- lookups are handled specially, no constraints here LookupGate num_slots lut_hash -> return () @@ -95,7 +88,7 @@ gateComputation gate = k -> error ( "gateConstraints/PoseidonMdsGate: unsupported width " ++ show k) RandomAccessGate num_bits num_copies num_extra_constants - -> todo + -> randomAccessGateConstraints (MkRACfg num_bits num_copies num_extra_constants) ReducingGate num_coeffs -> todo @@ -110,3 +103,28 @@ gateComputation gate = todo = error $ "gateConstraints: gate `" ++ takeWhile isAlpha (show gate) ++ "` not yet implemented" -------------------------------------------------------------------------------- + +-- computes `out = base ^ (sum 2^i e_i)` +-- order of witness variables: [ base, e[0],...,e[n-1], output, t[0]...t[n-1] ] +exponentiationGateConstraints :: Int -> Compute () +exponentiationGateConstraints num_power_bits = + do + let prev i = if i==0 then 1 else sqr (tmp_val (i-1)) + let comp i = prev i * (cur_bit i * base + (1 - cur_bit i)) + let eq i = comp i - tmp_val i + commitList [ eq i | i <- range num_power_bits ] + commit ( out - tmp_val (num_power_bits-1) ) + where + base = wire 0 + exp_bit i = wire (i+1) + out = wire (num_power_bits+1) + tmp_val i = wire (num_power_bits+2+i) + cur_bit i = exp_bit (num_power_bits - 1 - i) + sqr x = x*x + +-------------------------------------------------------------------------------- + +testExpoGate = runComputation testEvaluationVarsExt (gateComputation (ExponentiationGate 13)) + +-------------------------------------------------------------------------------- + diff --git a/src/Gate/RandomAccess.hs b/src/Gate/RandomAccess.hs new file mode 100644 index 0000000..98a27a8 --- /dev/null +++ b/src/Gate/RandomAccess.hs @@ -0,0 +1,91 @@ + +-- | The @RandomAccess@ gate + +{-# LANGUAGE StrictData, RecordWildCards #-} +module Gate.RandomAccess where + +-------------------------------------------------------------------------------- + +import Data.Foldable +import Control.Monad + +import Algebra.Goldilocks +import Algebra.Expr + +import Gate.Vars +import Gate.Computation + +import Misc.Aux + +-------------------------------------------------------------------------------- + +num_routed_columns = 80 :: Int + +-------------------------------------------------------------------------------- + +data RandomAccessGateConfig = MkRACfg + { ra_num_bits :: Int -- ^ number of bits in the index (so the vector has width @2^n@) + , ra_num_copies :: Int -- ^ how many copies of this operation is included in a row + , ra_num_extra_constants :: Int -- ^ number of extra cells used as in ConstantGate + } + deriving Show + +randomAccessGateConfig :: Int -> RandomAccessGateConfig +randomAccessGateConfig num_bits = ra_cfg where + veclen = 2 ^ num_bits + width = 2 + veclen + copies = Prelude.div num_routed_columns width + extra = min 2 (num_routed_columns - copies*width) + ra_cfg = MkRACfg + { ra_num_bits = num_bits + , ra_num_copies = copies + , ra_num_extra_constants = extra + } + +-------------------------------------------------------------------------------- + +randomAccessGateConstraints :: RandomAccessGateConfig -> Compute () +randomAccessGateConstraints (MkRACfg{..}) = do + + forM_ [0..ra_num_copies-1] $ \k -> do + + -- index bits are actual bits + commitList [ bits k j * (bits k j - 1) | j<-[0..ra_num_bits-1] ] + + -- bit decomposition is correct + let reconstr = foldr (\b acc -> 2*acc + b) 0 [ bits k j | j<-[0..ra_num_bits-1] ] + commit $ reconstr - index k + + let lkp_val = lookup_eq + [ bits k j | j<-[0..ra_num_bits-1] ] + [ inputs k i | i<-[0..veclen-1] ] + commit $ lkp_val - output k + + forM_ [0..ra_num_extra_constants-1] $ \j -> commit (cnst j - extra j) + + where + + veclen = 2 ^ ra_num_bits + width = 2 + veclen + + bits_start_at = width * ra_num_copies + ra_num_extra_constants + + -- witness variables + index k = wire $ k*width + 0 + output k = wire $ k*width + 1 + inputs k j = wire $ k*width + 2 + j + extra j = wire $ ra_num_copies * width + j + bits k j = wire $ bits_start_at + k*ra_num_bits + j + + into_pairs [] = [] + into_pairs (x:y:rest) = (x,y) : into_pairs rest + into_pairs [x] = error "into_pairs: odd input" + + lookup_eq [] [z] = z + lookup_eq (b:bits) values = lookup_eq bits $ map (\(x,y) -> x + b*(y-x)) (into_pairs values) + lookup_eq bits values = error $ "RandomAccessGate/lookup_eq: shouldn't happen: " ++ show (length bits, length values) + +-------------------------------------------------------------------------------- + +testRAGate = runComputation testEvaluationVarsExt (randomAccessGateConstraints $ randomAccessGateConfig 4) +