initial import the Haskell WIP code (parser + interpreter works on a simple example)

This commit is contained in:
Balazs Komuves 2025-03-13 17:55:37 +01:00
parent 0fc22c2a27
commit de29a073da
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
9 changed files with 973 additions and 0 deletions

6
.gitignore vendored Normal file
View File

@ -0,0 +1,6 @@
.DS_Store
*.hi
*.o
tmp/
build/
.ghc.environment.*

55
README.md Normal file
View File

@ -0,0 +1,55 @@
Circom witness generators
-------------------------
This piece of software takes the computation graph files generated by
[`circom-witnesscalc`](https://github.com/iden3/circom-witnesscalc),
and either interprets or compiles them to various algebra backends.
### Implementation status
Compiler (in Haskell):
- [x] parsing the graph files
- [x] naive interpreter
- [ ] constantine backend
- [ ] zikkurat backend
- [ ] arkworks backend
Nim witness generator:
- [x] parsing the graph files
- [ ] generating the witness
- [ ] proper error handling
### Testing & correctness
I haven't yet done any proper testing, apart from "works for our purposes".
### Circuit optimizations
NOTE: you _have to_ run `circom` with the `--O2` options, otherwise the
witness format will be most probably incompatible with the the one generated
by `circom-witnesscalc`.
### Graph file format
`circom-witnesscalc` produces binary files encoding a computation graph.
This has the following format:
- magic header: "wtns.graph.001" (14 bytes)
- number of nodes (8 bytes little endian)
- list of nodes, in protobuf serialized format. Each prefixed by a `varint` length
- `GraphMetaData`, in protobuf
- 8 bytes offset (yes, _at the very end_...), pointing to the start of `GraphMetaData`
Node format:
- varint length prefix
- protobuf tag-byte (lower 3 bits are `0x02`, upper bits are node type 1..5)
- varint length
- several record fields
- protobuf tag-byte (lower 3 bits are 0x00, upper bits are field index 1,2,3,4)
- value is varint word32, except for ConstantNode when it's a little-endian bytes
(wrapped a few times, because protobuf is being protobuf...)

191
src/BN254.hs Normal file
View File

@ -0,0 +1,191 @@
-- | The BN254 scalar field
{-# LANGUAGE Strict, BangPatterns #-}
module BN254 where
--------------------------------------------------------------------------------
import Prelude hiding (div)
import qualified Prelude
import Data.Bits
import Data.Word
import Data.Ratio
import System.Random
import Text.Printf
import Misc
--------------------------------------------------------------------------------
fieldPrime :: Integer
fieldPrime = 21888242871839275222246405745257275088548364400416034343698204186575808495617
modP :: Integer -> Integer
modP x = mod x fieldPrime
halfPrimePlus1 :: Integer
halfPrimePlus1 = 1 + Prelude.div fieldPrime 2
--------------------------------------------------------------------------------
newtype F
= MkF Integer
deriving (Eq,Show)
fromF :: F -> Integer
fromF (MkF x) = x
-- from the circom docs: @val(z) = z-p if p/2 +1 <= z < p@
signedFromF :: F -> Integer
signedFromF (MkF x) = if x >= halfPrimePlus1 then x - fieldPrime else x
toF :: Integer -> F
toF = MkF . modP
isZero :: F -> Bool
isZero (MkF x) = (x == 0)
fromBool :: Bool -> F
fromBool b = if b then 1 else 0
toBool :: F -> Bool
toBool = not . isZero
--------------------------------------------------------------------------------
neg :: F -> F
neg (MkF x) = toF (negate x)
add :: F -> F -> F
add (MkF x) (MkF y) = toF (x+y)
sub :: F -> F -> F
sub (MkF x) (MkF y) = toF (x-y)
mul :: F -> F -> F
mul (MkF x) (MkF y) = toF (x*y)
instance Num F where
fromInteger = toF
negate = neg
(+) = add
(-) = sub
(*) = mul
abs = id
signum _ = toF 1
square :: F -> F
square x = x*x
rndF :: IO F
rndF = MkF <$> randomRIO (0,fieldPrime-1)
--------------------------------------------------------------------------------
pow :: F -> Integer -> F
pow x0 exponent
| exponent < 0 = error "power: expecting positive exponent"
| otherwise = go 1 x0 exponent
where
go !acc _ 0 = acc
go !acc s e = go acc' s' (shiftR e 1) where
s' = s*s
acc' = if e .&. 1 == 0 then acc else acc*s
invNaive :: F -> F
invNaive x = pow x (fieldPrime - 2)
divNaive :: F -> F -> F
divNaive x y = x * invNaive y
--------------------------------------------------------------------------------
instance Fractional F where
fromRational q = fromInteger (numerator q) / fromInteger (denominator q)
recip = inv
(/) = div
--------------------------------------------------------------------------------
fromBytesLE :: [Word8] -> F
fromBytesLE = toF . integerFromBytesLE
integerFromBytesLE :: [Word8] -> Integer
integerFromBytesLE = go where
go [] = 0
go (b:bs) = fromIntegral b + (shiftL (go bs) 8)
--------------------------------------------------------------------------------
instance ShowHex Integer where showHex = printf "0x%x"
instance ShowHex F where showHex (MkF x) = showHex x
--------------------------------------------------------------------------------
-- | Inversion (using Euclid's algorithm)
inv :: F -> F
inv (MkF a)
| a == 0 = 0 -- error "field inverse of zero (generic prime)"
| otherwise = MkF (euclid 1 0 a fieldPrime)
-- | Division via Euclid's algorithm
div :: F -> F -> F
div (MkF a) (MkF b)
| b == 0 = 0 -- error "field division by zero (generic prime)"
| otherwise = MkF (euclid a 0 b fieldPrime)
--------------------------------------------------------------------------------
-- * Euclidean algorithm
-- | Extended binary Euclidean algorithm
euclid :: Integer -> Integer -> Integer -> Integer -> Integer
euclid !x1 !x2 !u !v = go x1 x2 u v where
p = fieldPrime
halfp1 = shiftR (p+1) 1
modp :: Integer -> Integer
modp n = mod n p
-- Inverse using the binary Euclidean algorithm
euclid :: Integer -> Integer
euclid a
| a == 0 = 0
| otherwise = go 1 0 a p
go :: Integer -> Integer -> Integer -> Integer -> Integer
go !x1 !x2 !u !v
| u==1 = x1
| v==1 = x2
| otherwise = stepU x1 x2 u v
stepU :: Integer -> Integer -> Integer -> Integer -> Integer
stepU !x1 !x2 !u !v = if even u
then let u' = shiftR u 1
x1' = if even x1 then shiftR x1 1 else shiftR x1 1 + halfp1
in stepU x1' x2 u' v
else stepV x1 x2 u v
stepV :: Integer -> Integer -> Integer -> Integer -> Integer
stepV !x1 !x2 !u !v = if even v
then let v' = shiftR v 1
x2' = if even x2 then shiftR x2 1 else shiftR x2 1 + halfp1
in stepV x1 x2' u v'
else final x1 x2 u v
final :: Integer -> Integer -> Integer -> Integer -> Integer
final !x1 !x2 !u !v = if u>=v
then let u' = u-v
x1' = if x1 >= x2 then modp (x1-x2) else modp (x1+p-x2)
in go x1' x2 u' v
else let v' = v-u
x2' = if x2 >= x1 then modp (x2-x1) else modp (x2+p-x1)
in go x1 x2' u v'
--------------------------------------------------------------------------------

104
src/Graph.hs Normal file
View File

@ -0,0 +1,104 @@
{-# LANGUAGE StrictData #-}
module Graph where
--------------------------------------------------------------------------------
import Text.Printf
--------------------------------------------------------------------------------
import Data.Word
data Graph = Graph
{ graphNodes :: [Node]
, graphMeta :: GraphMetaData
}
deriving Show
--------------------------------------------------------------------------------
-- | Unary operations
data UnoOp
= Neg -- ^ @= 0@
| Id -- ^ @= 1@
deriving (Eq,Enum,Bounded,Show)
data DuoOp
= Mul -- ^ @= 0@
| Div -- ^ @= 1@
| Add -- ^ @= 2@
| Sub -- ^ @= 3@
| Pow -- ^ @= 4@
| Idiv -- ^ @= 5@
| Mod -- ^ @= 6@
| Eq_ -- ^ @= 7@
| Neq -- ^ @= 8@
| Lt -- ^ @= 9@
| Gt -- ^ @= 10@
| Leq -- ^ @= 11@
| Geq -- ^ @= 12@
| Land -- ^ @= 13@
| Lor -- ^ @= 14@
| Shl -- ^ @= 15@
| Shr -- ^ @= 16@
| Bor -- ^ @= 17@
| Band -- ^ @= 18@
| Bxor -- ^ @= 19@
deriving (Eq,Enum,Bounded,Show)
data TresOp
= TernCond -- ^ @= 0@
deriving (Eq,Enum,Bounded,Show)
--------------------------------------------------------------------------------
newtype BigUInt
= BigUInt [Word8] -- ^ little endian
showBigUInt :: BigUInt -> String
showBigUInt (BigUInt bytes) = "0x" ++ concatMap f (reverse bytes) where
f :: Word8 -> String
f = printf "%02x"
instance Show BigUInt where show = showBigUInt
newtype InputNode
= InputNode Word32
deriving (Show)
newtype ConstantNode
= ConstantNode BigUInt
deriving (Show)
data UnoOpNode = UnoOpNode !UnoOp !Word32 deriving (Show)
data DuoOpNode = DuoOpNode !DuoOp !Word32 !Word32 deriving (Show)
data TresOpNode = TresOpNode !TresOp !Word32 !Word32 !Word32 deriving (Show)
data Node
= AnInputNode InputNode -- @= 1@
| AConstantNode ConstantNode -- @= 2@
| AnUnoOpNode UnoOpNode -- @= 3@
| ADuoOpNode DuoOpNode -- @= 4@
| ATresOpNode TresOpNode -- @= 5@
deriving (Show)
data SignalDescription = SignalDescription
{ signalOffset :: !Word32
, signalLength :: !Word32
}
deriving (Show)
newtype WitnessMapping
= WitnessMapping { fromWitnessMapping :: [Word32] }
deriving (Show)
type CircuitInputs = [(String, SignalDescription)]
data GraphMetaData = GraphMetaData
{ witnessMapping :: WitnessMapping
, inputSignals :: CircuitInputs
}
deriving (Show)
--------------------------------------------------------------------------------

34
src/Main.hs Normal file
View File

@ -0,0 +1,34 @@
module Main where
--------------------------------------------------------------------------------
import Data.Map (Map)
import qualified Data.Map as Map
import Witness
import Parser
--------------------------------------------------------------------------------
(~>) :: String -> a -> (String, a)
(~>) = (,)
infix 2 ~>
testInputs :: Map String [Integer]
testInputs = Map.fromList
[ "a" ~> [0xff01]
, "b" ~> [0xff02]
]
--------------------------------------------------------------------------------
main :: IO ()
main = do
Right graph <- parseGraphFile "../tmp/graph2.bin"
putStrLn ""
print graph
let wtns = witnessCalc testInputs graph
putStrLn ""
print wtns

33
src/Misc.hs Normal file
View File

@ -0,0 +1,33 @@
{-# LANGUAGE Strict #-}
module Misc where
--------------------------------------------------------------------------------
import Data.Bits
--------------------------------------------------------------------------------
class ShowHex a where
showHex :: a -> String
printHex :: ShowHex a => a -> IO ()
printHex x = putStrLn (showHex x)
--------------------------------------------------------------------------------
-- Integer logarithm
-- | Largest integer @k@ such that @2^k@ is smaller or equal to @n@
integerLog2 :: Integer -> Int
integerLog2 n = go n where
go 0 = -1
go k = 1 + go (shiftR k 1)
-- | Smallest integer @k@ such that @2^k@ is larger or equal to @n@
ceilingLog2 :: Integer -> Int
ceilingLog2 0 = 0
ceilingLog2 n = 1 + go (n-1) where
go 0 = -1
go k = 1 + go (shiftR k 1)
--------------------------------------------------------------------------------

345
src/Parser.hs Normal file
View File

@ -0,0 +1,345 @@
-- | Parsing the graph binary format
{-# LANGUAGE Strict, PackageImports, BangPatterns #-}
module Parser where
--------------------------------------------------------------------------------
import Data.Bits
import Data.Word
import Data.List
import Data.Ord
import Control.Monad
import Control.Applicative
import System.IO
import Data.ByteString.Lazy (ByteString)
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Char8 as LC
import "binary" Data.Binary.Get
import "binary" Data.Binary.Builder as Builder
import Graph
--------------------------------------------------------------------------------
{-
test :: IO ()
test = do
Right graph <- parseGraphFile "../tmp/graph.bin"
print graph
-}
{-
nodeEx1 = 0x06 : nodeEx1' :: [Word8]
nodeEx2 = 0x08 : nodeEx2' :: [Word8]
nodeEx1' = 0x22 : 0x04 : duoNodeEx1 :: [Word8]
nodeEx2' = 0x22 : 0x06 : duoNodeEx2 :: [Word8]
duoNodeEx1 = [ 0x10 , 0x05 , 0x18 , 0x05 ] :: [Word8]
duoNodeEx2 = [ 0x08 , 0x02 , 0x10 , 0x03 , 0x18 , 0x04 ] :: [Word8]
-}
--------------------------------------------------------------------------------
type Msg = String
parseGraphFile :: FilePath -> IO (Either Msg Graph)
parseGraphFile fname = do
h <- openBinaryFile fname ReadMode
ei <- readGraphFile h
hClose h
return ei
hGetBytes :: Handle -> Int -> IO ByteString
hGetBytes h n = L.hGet h (fromIntegral n)
hSeekInt :: Handle -> Int -> IO ()
hSeekInt h ofs = hSeek h AbsoluteSeek (fromIntegral ofs)
readGraphFile :: Handle -> IO (Either Msg Graph)
readGraphFile h = do
flen <- (fromIntegral :: Integer -> Int) <$> hFileSize h
magic <- hGetBytes h (length magicHeader)
if magic /= LC.pack magicHeader
then return $ Left "magic header not found or invalid"
else do
hSeekInt h (flen - 8)
offset <- (fromIntegral . runGet getWord64le) <$> hGetBytes h 8
putStrLn $ "metadata offset = " ++ show offset
if (offset >= flen) || (offset <= 18)
then return $ Left "invalid final `graphMetaData` offset bytes"
else do
hSeekInt h (length magicHeader)
part1 <- hGetBytes h (offset - length magicHeader)
part2 <- hGetBytes h (flen - offset - 8)
return (Right $ Graph (parseNodes part1) (parseMeta part2))
magicHeader :: String
magicHeader = "wtns.graph.001"
parseNodes :: ByteString -> [Node]
parseNodes = runGet getNodes
parseMeta :: ByteString -> GraphMetaData
parseMeta = runGet getMetaData
--------------------------------------------------------------------------------
varInt' :: Get Word64
varInt' = go 0 where
go !cnt = if cnt >= 8 then return 0 else do
w <- getWord8
if (w < 128)
then return (fromIntegral w)
else do
let x = fromIntegral (w .&. 127)
y <- go (cnt+1)
return (x + 128*y)
varInt :: Get Int
varInt = fromIntegral <$> varInt'
varUInt :: Get Word32
varUInt = fromIntegral <$> varInt'
--------------------------------------------------------------------------------
getNodes :: Get [Node]
getNodes = do
n <- getWord64le
replicateM (fromIntegral n) getNode
-- | with varint length prefix
getNode :: Get Node
getNode = do
len <- varInt
bs <- getLazyByteString (fromIntegral len)
return (runGet getNode' bs)
-- | without varint length prefix
getNode' :: Get Node
getNode' = do
nodetype <- getFieldId LEN
case nodetype of
1 -> AnInputNode <$> getInputNode
2 -> AConstantNode <$> getConstantNode
3 -> AnUnoOpNode <$> getUnoOpNode
4 -> ADuoOpNode <$> getDuoOpNode
5 -> ATresOpNode <$> getTresOpNode
_ -> error "unexpected node type"
getInputNode :: Get InputNode
getInputNode = do
SomeNode idx _ _ _ <- getSomeNode
return (InputNode idx)
getConstantNode :: Get ConstantNode
getConstantNode = do
len <- varInt
bs <- getLazyByteString (fromIntegral len)
return $ ConstantNode (runGet getBigUInt bs)
getBigUInt :: Get BigUInt
getBigUInt = do
fld <- getFieldId LEN
if fld /= 1
then error "getBigUInt"
else do
len <- varInt
bs <- getLazyByteString (fromIntegral len)
return $ BigUInt (runGet getByteList bs)
getByteList :: Get [Word8]
getByteList = do
fld <- getFieldId LEN
if fld /= 1
then error "getByteList"
else do
len <- varInt
bs <- getLazyByteString (fromIntegral len)
return (L.unpack bs)
getUnoOpNode :: Get UnoOpNode
getUnoOpNode = do
SomeNode op arg1 _ _ <- getSomeNode
return (UnoOpNode (wordToEnum op) arg1)
getDuoOpNode :: Get DuoOpNode
getDuoOpNode = do
SomeNode op arg1 arg2 _ <- getSomeNode
return (DuoOpNode (wordToEnum op) arg1 arg2)
getTresOpNode :: Get TresOpNode
getTresOpNode = do
SomeNode op arg1 arg2 arg3 <- getSomeNode
return (TresOpNode (wordToEnum op) arg1 arg2 arg3)
wordToEnum :: Enum a => Word32 -> a
wordToEnum w = toEnum (fromIntegral w)
--------------------------------------------------------------------------------
data SomeNode = SomeNode
{ field1 :: Word32
, field2 :: Word32
, field3 :: Word32
, field4 :: Word32
}
deriving Show
defaultSomeNode :: SomeNode
defaultSomeNode = SomeNode 0 0 0 0
insert1 :: (Int,Word32) -> SomeNode -> SomeNode
insert1 (idx,val) old = case idx of
1 -> old { field1 = val }
2 -> old { field2 = val }
3 -> old { field3 = val }
4 -> old { field4 = val }
insertMany :: [(Int,Word32)] -> SomeNode -> SomeNode
insertMany list old = foldl' (flip insert1) old list
getSomeNode :: Get SomeNode
getSomeNode = do
len <- varInt
bs <- getLazyByteString (fromIntegral len)
let list = runGet getRecord bs
return $ insertMany list defaultSomeNode
--------------------------------------------------------------------------------
-- TODO: refactor this mess
getMetaData :: Get GraphMetaData
getMetaData = do
len <- varInt
mapping <- getWitnessMapping
inputs <- getCircuitInputs
return $ GraphMetaData mapping inputs
getWitnessMapping :: Get WitnessMapping
getWitnessMapping = do
fld <- getFieldId LEN
if fld /= 1
then error "getWitnessMapping: expecting field 1"
else do
len <- varInt
bs <- getLazyByteString (fromIntegral len)
return $ WitnessMapping (runGet worker bs)
where
worker :: Get [Word32]
worker = isEmpty >>= \b -> if b
then return []
else (:) <$> varUInt <*> worker
getCircuitInputs :: Get CircuitInputs
getCircuitInputs = worker where
worker :: Get [(String, SignalDescription)]
worker = isEmpty >>= \b -> if b
then return []
else (:) <$> getSingleInput <*> worker
getSingleInput :: Get (String, SignalDescription)
getSingleInput = do
fld <- getFieldId LEN
if fld /= 2
then error "getCircuitInputs: expecting field 2"
else do
len <- varInt
bs <- getLazyByteString (fromIntegral len)
return $ runGet inputHelper bs
inputHelper = do
name <- getName
signal <- getSignal
return (name,signal)
getName :: Get String
getName = do
fld <- getFieldId LEN
if fld /= 1
then error "getCircuitInputs/getName: expecting field 1"
else do
len <- varInt
bs <- getLazyByteString (fromIntegral len)
return (LC.unpack bs)
getSignal :: Get SignalDescription
getSignal = do
fld <- getFieldId LEN
if fld /= 2
then error "getCircuitInputs/getSignal: expecting field 2"
else do
len <- varInt
bs <- getLazyByteString (fromIntegral len)
return $ runGet signalHelper bs
signalHelper = do
ofs <- getSignalOffset
len <- getSignalLength
return $ SignalDescription { signalOffset = ofs , signalLength = len }
getSignalOffset = do
fld <- getFieldId VARINT
if fld /= 1
then error "getCircuitInputs/getSignalOffset: expecting field 1"
else varUInt
getSignalLength = do
fld <- getFieldId VARINT
if fld /= 2
then error "getCircuitInputs/getSignalLength: expecting field 2"
else varUInt
--------------------------------------------------------------------------------
-- * protobuf stuff
-- | There are six wire types: VARINT, I64, LEN, SGROUP, EGROUP, and I32
data WireType
= VARINT -- ^ used for: int32, int64, uint32, uint64, sint32, sint64, bool, enum
| I64 -- ^ used for: fixed64, sfixed64, double
| LEN -- ^ used for: string, bytes, embedded messages, packed repeated fields
| SGROUP -- ^ used for: group start (deprecated)
| EGROUP -- ^ used for: group end (deprecated)
| I32 -- ^ fixed32, sfixed32, float
deriving (Eq,Show,Enum,Bounded)
type FieldId = Int
getFieldId :: WireType -> Get FieldId
getFieldId wty = do
tag <- getWord8
let (fld,wty') = decodeTag tag
if wty == wty'
then return fld
else error "getFieldId: unexpected protobuf wire type"
decodeTag_ :: Word8 -> FieldId
decodeTag_ = fst . decodeTag
decodeTag :: Word8 -> (FieldId, WireType)
decodeTag w = (fld , wty) where
fld = fromIntegral (shiftR w 3)
wty = toEnum (fromIntegral (w .&. 7))
-- (index, value) pair
getEntry :: Get (Int,Word32)
getEntry = do
idx <- getFieldId VARINT
val <- varUInt
return (idx,val)
-- list of (index, value) pairs
getRecord :: Get [(Int,Word32)]
getRecord = sort <$> go where
go = isEmpty >>= \b -> if b
then return []
else (:) <$> getEntry <*> go
--------------------------------------------------------------------------------

101
src/Semantics.hs Normal file
View File

@ -0,0 +1,101 @@
-- | See <https://docs.circom.io/circom-language/basic-operators> for the official
-- semantics of the operations
{-# LANGUAGE StrictData, DeriveFunctor #-}
module Semantics where
--------------------------------------------------------------------------------
import Data.Bits
import BN254
import Misc
--------------------------------------------------------------------------------
data PrimOp a
-- unary
= Neg a
| Id a
-- binary
| Mul a a
| Div a a
| Add a a
| Sub a a
| Pow a a
| Idiv a a
| Mod a a
| Eq_ a a
| Neq a a
| Lt a a
| Gt a a
| Leq a a
| Geq a a
| Land a a
| Lor a a
| Shl a a
| Shr a a
| Bor a a
| Band a a
| Bxor a a
-- ternary
| Cond a a a
deriving (Show,Functor)
--------------------------------------------------------------------------------
evalPrimOp :: PrimOp F -> F
evalPrimOp prim = case prim of
Neg x -> neg x
Id x -> x
Mul x y -> mul x y
Div x y -> BN254.div x y
Add x y -> add x y
Sub x y -> sub x y
Pow x y -> pow x (fromF y)
Idiv x y -> if isZero y then 0 else toF $ Prelude.div (fromF x) (fromF y)
Mod x y -> if isZero y then 0 else toF $ Prelude.mod (fromF x) (fromF y)
Eq_ x y -> fromBool (x == y)
Neq x y -> fromBool (x /= y)
Lt x y -> fromBool (signedFromF x < signedFromF y)
Gt x y -> fromBool (signedFromF x > signedFromF y)
Leq x y -> fromBool (signedFromF x <= signedFromF y)
Geq x y -> fromBool (signedFromF x >= signedFromF y)
Land x y -> fromBool (toBool x && toBool y)
Lor x y -> fromBool (toBool x || toBool y)
Shl x y -> shiftLeft x (fromF y)
Shr x y -> shiftRight x (fromF y)
Bor x y -> toF (fromF x .|. fromF y)
Band x y -> toF (fromF x .&. fromF y)
Bxor x y -> toF (fromF x `xor` fromF y)
Cond b x y -> if toBool b then x else y
--------------------------------------------------------------------------------
fieldMask :: Integer
fieldMask = shiftL 1 fieldBits - 1
fieldBits :: Int
fieldBits = ceilingLog2 fieldPrime
shiftRight :: F -> Integer -> F
shiftRight (MkF x) k = if k < halfPrimePlus1
then if k > fromIntegral fieldBits
then 0
else MkF (shiftR x (fromInteger k))
else shiftLeft (MkF x) (fieldPrime - k)
shiftLeft :: F -> Integer -> F
shiftLeft (MkF x) k = if k < halfPrimePlus1
then if k > fromIntegral fieldBits
then 0
else toF (shiftL x (fromIntegral k) .&. fieldMask)
else shiftRight (MkF x) (fieldPrime - k)
--------------------------------------------------------------------------------
{-
notYet :: a
notYet = error "not yet implemented"
-}

104
src/Witness.hs Normal file
View File

@ -0,0 +1,104 @@
{-# LANGUAGE StrictData #-}
module Witness where
--------------------------------------------------------------------------------
import Data.Array
import Data.Word
import qualified Data.Map as Map ; import Data.Map (Map )
import qualified Data.IntMap as IntMap ; import Data.IntMap (IntMap)
import BN254
import qualified Semantics as S ; import Semantics ( PrimOp , evalPrimOp )
import qualified Graph as G ; import Graph ( Graph(..) , Node(..) , UnoOpNode(..) , DuoOpNode(..) , TresOpNode(..) , SignalDescription(..) )
--------------------------------------------------------------------------------
type Witness = Array Int F
type Inputs = Map String [Integer]
witnessCalc :: Inputs -> Graph -> Witness
witnessCalc inputs (Graph nodes meta) = witness where
nodesArr = listArray (0,length nodes-1) nodes
rawWitness = evaluateNodes rawInputs nodesArr
rawInputs = convertInputs (G.inputSignals meta) inputs
mapping_ = G.fromWitnessMapping (G.witnessMapping meta)
wtnslen = length mapping_
mapping = listArray (0,wtnslen-1) mapping_
witness = listArray (0,wtnslen-1) $ [ rawWitness!(fromIntegral (mapping!i)) | i<-[0..wtnslen-1] ]
convertInputs :: [(String,SignalDescription)] -> Map String [Integer] -> IntMap F
convertInputs descTable inputTable = IntMap.fromList $ concatMap f descTable where
f :: (String,SignalDescription) -> [(Int,F)]
f (name,desc) = case Map.lookup name inputTable of
Nothing -> error $ "input signal `" ++ name ++ "` not found in the given inputs!"
Just values -> if length values /= fromIntegral (signalLength desc)
then error $ "input signal `" ++ name ++ "` has incorrect size"
else let ofs = fromIntegral (signalOffset desc) :: Int
in (0,1) : zip [ofs..] (map toF values)
--------------------------------------------------------------------------------
type RawInputs = IntMap F
evaluateNodes :: RawInputs -> Array Int Node -> Array Int F
evaluateNodes inputs nodes = witness where
(a,b) = bounds nodes
witness = array (0,b-a) $ [ (i-a, worker i) | i<-[a..b] ]
lkp :: Word32 -> F
lkp i = witness!(fromIntegral i)
worker i = case nodes!i of
AnInputNode (G.InputNode idx) -> getInputValue idx
AConstantNode (G.ConstantNode big) -> fromBigUInt big
AnUnoOpNode unoOpNode -> evalPrimOp (fmap lkp $ unoToPrimOp unoOpNode )
ADuoOpNode duoOpNode -> evalPrimOp (fmap lkp $ duoToPrimOp duoOpNode )
ATresOpNode tresOpNode -> evalPrimOp (fmap lkp $ tresToPrimOp tresOpNode)
getInputValue j = case IntMap.lookup (fromIntegral j) inputs of
Just y -> y
Nothing -> error ("input value not found at index " ++ show j)
fromBigUInt (G.BigUInt bs) = fromBytesLE bs
--------------------------------------------------------------------------------
unoToPrimOp :: UnoOpNode -> PrimOp Word32
unoToPrimOp (UnoOpNode op arg1) = case op of
G.Neg -> S.Neg arg1
G.Id -> S.Id arg1
duoToPrimOp :: DuoOpNode -> PrimOp Word32
duoToPrimOp (DuoOpNode op arg1 arg2) = case op of
G.Mul -> S.Mul arg1 arg2
G.Div -> S.Div arg1 arg2
G.Add -> S.Add arg1 arg2
G.Sub -> S.Sub arg1 arg2
G.Pow -> S.Pow arg1 arg2
G.Idiv -> S.Idiv arg1 arg2
G.Mod -> S.Mod arg1 arg2
G.Eq_ -> S.Eq_ arg1 arg2
G.Neq -> S.Neq arg1 arg2
G.Lt -> S.Lt arg1 arg2
G.Gt -> S.Gt arg1 arg2
G.Leq -> S.Leq arg1 arg2
G.Geq -> S.Geq arg1 arg2
G.Land -> S.Land arg1 arg2
G.Lor -> S.Lor arg1 arg2
G.Shl -> S.Shl arg1 arg2
G.Shr -> S.Shr arg1 arg2
G.Bor -> S.Bor arg1 arg2
G.Band -> S.Band arg1 arg2
G.Bxor -> S.Bxor arg1 arg2
tresToPrimOp :: TresOpNode -> PrimOp Word32
tresToPrimOp (TresOpNode op arg1 arg2 arg3) = case op of
G.TernCond -> S.Cond arg1 arg2 arg3
--------------------------------------------------------------------------------