fix the exponentiation gate and implement random access gate

This commit is contained in:
Balazs Komuves 2025-01-23 13:14:55 +01:00
parent 39c7316be0
commit 4e88c0defb
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
4 changed files with 150 additions and 18 deletions

View File

@ -45,7 +45,7 @@ Supported gates:
- [x] PublicInputGate
- [x] PoseidonGate
- [ ] PoseidonMdsGate
- [ ] RandomAccessGate
- [X] RandomAccessGate
- [ ] ReducingGate
- [ ] ReducingExtensionGate

View File

@ -90,22 +90,29 @@ instance Pretty v => Pretty (LocalDef v) where
-- | A straightline program encoding the computation of constraints
data StraightLine = MkStraightLine
{ localdefs :: [LocalDef Var_] -- ^ local definitions, in reverse order
, commits :: [Expr_] -- ^ committed constraints, in reverse order
{ localdefs :: [LocalDef Var_] -- ^ local definitions (during compilation, in reverse order)
, commits :: [Expr_] -- ^ committed constraints (during compilation, in reverse order)
, counter :: Int -- ^ fresh variable counter
}
deriving Show
reverseStraightLine :: StraightLine -> StraightLine
reverseStraightLine (MkStraightLine{..}) = MkStraightLine
{ localdefs = reverse localdefs
, commits = reverse commits
, counter = counter
}
emptyStraightLine :: StraightLine
emptyStraightLine = MkStraightLine [] [] 0
printStraightLine :: StraightLine -> IO ()
printStraightLine (MkStraightLine{..}) = do
forM_ (reverse localdefs) $ \def -> putStrLn (pretty def)
forM_ (reverse commits ) $ \expr -> putStrLn $ "constraint 0 == " ++ (pretty expr)
forM_ (localdefs) $ \def -> putStrLn (pretty def)
forM_ (commits ) $ \expr -> putStrLn $ "constraint 0 == " ++ (pretty expr)
compileToStraightLine :: Compute () -> StraightLine
compileToStraightLine = fst . go emptyStraightLine where
compileToStraightLine = reverseStraightLine . fst . go emptyStraightLine where
go :: StraightLine -> Compute a -> (StraightLine,a)
go state instr = case instr of
Return x -> (state,x)
@ -149,10 +156,13 @@ runStraightLine = runStraightLine' emptyScope
runStraightLine' :: Scope FExt -> EvaluationVars FExt -> StraightLine -> [FExt]
runStraightLine' iniScope vars (MkStraightLine{..}) = result where
finalScope = foldl' worker iniScope (reverse localdefs)
result = evalConstraints finalScope vars (reverse commits)
finalScope = foldl' worker iniScope (localdefs)
result = evalConstraints finalScope vars (commits)
worker !scope (MkLocalDef i _ rhs) = IntMap.insert i (evalConstraint scope vars rhs) scope
runComputation :: EvaluationVars FExt -> Compute () -> [FExt]
runComputation evalvars action = runStraightLine evalvars (compileToStraightLine action)
--------------------------------------------------------------------------------
-- * Evaluation
@ -169,6 +179,19 @@ data EvaluationVars a = MkEvaluationVars
}
deriving (Show,Functor)
-- | used for testing the gate constraint
testEvaluationVarsBase :: EvaluationVars F
testEvaluationVarsBase = MkEvaluationVars
{ local_selectors = listArray (0, 0) []
, local_constants = listArray (0, 1) [666,77]
, local_wires = listArray (0,134) [ 1001 + 71 * fromInteger i | i<-[0..134] ]
, public_inputs_hash = [101,102,103,104]
}
testEvaluationVarsExt :: EvaluationVars FExt
testEvaluationVarsExt = fmap f testEvaluationVarsBase where
f x = MkExt x 13
evalConstraint :: Scope FExt -> EvaluationVars FExt -> Constraint -> FExt
evalConstraint scope (MkEvaluationVars{..}) expr = evalExprWith f expr where
f var = case var of

View File

@ -22,6 +22,7 @@ import Gate.Base
import Gate.Vars
import Gate.Computation
import Gate.Poseidon
import Gate.RandomAccess
import Misc.Aux
@ -62,15 +63,7 @@ gateComputation gate =
-- computes `out = base ^ (sum 2^i e_i)`
-- order of witness variables: [ base, e[0],...,e[n-1], output, t[0]...t[n-1] ]
ExponentiationGate num_power_bits
-> let base = wire 0
exp_bit i = wire (i+1)
out = wire (num_power_bits+1)
tmp_val 0 = 1
tmp_val i = wire (num_power_bits+1+i)
cur_bit i = exp_bit (num_power_bits - 1 - i)
eq i = tmp_val (i-1) * (cur_bit i * base + 1 - cur_bit i) - tmp_val i
in commitList $ [ eq i | i <- range num_power_bits ] ++ [ out - tmp_val (num_power_bits-1) ]
ExponentiationGate num_power_bits -> exponentiationGateConstraints num_power_bits
-- lookups are handled specially, no constraints here
LookupGate num_slots lut_hash -> return ()
@ -95,7 +88,7 @@ gateComputation gate =
k -> error ( "gateConstraints/PoseidonMdsGate: unsupported width " ++ show k)
RandomAccessGate num_bits num_copies num_extra_constants
-> todo
-> randomAccessGateConstraints (MkRACfg num_bits num_copies num_extra_constants)
ReducingGate num_coeffs
-> todo
@ -110,3 +103,28 @@ gateComputation gate =
todo = error $ "gateConstraints: gate `" ++ takeWhile isAlpha (show gate) ++ "` not yet implemented"
--------------------------------------------------------------------------------
-- computes `out = base ^ (sum 2^i e_i)`
-- order of witness variables: [ base, e[0],...,e[n-1], output, t[0]...t[n-1] ]
exponentiationGateConstraints :: Int -> Compute ()
exponentiationGateConstraints num_power_bits =
do
let prev i = if i==0 then 1 else sqr (tmp_val (i-1))
let comp i = prev i * (cur_bit i * base + (1 - cur_bit i))
let eq i = comp i - tmp_val i
commitList [ eq i | i <- range num_power_bits ]
commit ( out - tmp_val (num_power_bits-1) )
where
base = wire 0
exp_bit i = wire (i+1)
out = wire (num_power_bits+1)
tmp_val i = wire (num_power_bits+2+i)
cur_bit i = exp_bit (num_power_bits - 1 - i)
sqr x = x*x
--------------------------------------------------------------------------------
testExpoGate = runComputation testEvaluationVarsExt (gateComputation (ExponentiationGate 13))
--------------------------------------------------------------------------------

91
src/Gate/RandomAccess.hs Normal file
View File

@ -0,0 +1,91 @@
-- | The @RandomAccess@ gate
{-# LANGUAGE StrictData, RecordWildCards #-}
module Gate.RandomAccess where
--------------------------------------------------------------------------------
import Data.Foldable
import Control.Monad
import Algebra.Goldilocks
import Algebra.Expr
import Gate.Vars
import Gate.Computation
import Misc.Aux
--------------------------------------------------------------------------------
num_routed_columns = 80 :: Int
--------------------------------------------------------------------------------
data RandomAccessGateConfig = MkRACfg
{ ra_num_bits :: Int -- ^ number of bits in the index (so the vector has width @2^n@)
, ra_num_copies :: Int -- ^ how many copies of this operation is included in a row
, ra_num_extra_constants :: Int -- ^ number of extra cells used as in ConstantGate
}
deriving Show
randomAccessGateConfig :: Int -> RandomAccessGateConfig
randomAccessGateConfig num_bits = ra_cfg where
veclen = 2 ^ num_bits
width = 2 + veclen
copies = Prelude.div num_routed_columns width
extra = min 2 (num_routed_columns - copies*width)
ra_cfg = MkRACfg
{ ra_num_bits = num_bits
, ra_num_copies = copies
, ra_num_extra_constants = extra
}
--------------------------------------------------------------------------------
randomAccessGateConstraints :: RandomAccessGateConfig -> Compute ()
randomAccessGateConstraints (MkRACfg{..}) = do
forM_ [0..ra_num_copies-1] $ \k -> do
-- index bits are actual bits
commitList [ bits k j * (bits k j - 1) | j<-[0..ra_num_bits-1] ]
-- bit decomposition is correct
let reconstr = foldr (\b acc -> 2*acc + b) 0 [ bits k j | j<-[0..ra_num_bits-1] ]
commit $ reconstr - index k
let lkp_val = lookup_eq
[ bits k j | j<-[0..ra_num_bits-1] ]
[ inputs k i | i<-[0..veclen-1] ]
commit $ lkp_val - output k
forM_ [0..ra_num_extra_constants-1] $ \j -> commit (cnst j - extra j)
where
veclen = 2 ^ ra_num_bits
width = 2 + veclen
bits_start_at = width * ra_num_copies + ra_num_extra_constants
-- witness variables
index k = wire $ k*width + 0
output k = wire $ k*width + 1
inputs k j = wire $ k*width + 2 + j
extra j = wire $ ra_num_copies * width + j
bits k j = wire $ bits_start_at + k*ra_num_bits + j
into_pairs [] = []
into_pairs (x:y:rest) = (x,y) : into_pairs rest
into_pairs [x] = error "into_pairs: odd input"
lookup_eq [] [z] = z
lookup_eq (b:bits) values = lookup_eq bits $ map (\(x,y) -> x + b*(y-x)) (into_pairs values)
lookup_eq bits values = error $ "RandomAccessGate/lookup_eq: shouldn't happen: " ++ show (length bits, length values)
--------------------------------------------------------------------------------
testRAGate = runComputation testEvaluationVarsExt (randomAccessGateConstraints $ randomAccessGateConfig 4)