mirror of
https://github.com/logos-storage/circom-witnessgen.git
synced 2026-01-04 05:53:05 +00:00
initial import the Haskell WIP code (parser + interpreter works on a simple example)
This commit is contained in:
parent
0fc22c2a27
commit
de29a073da
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
.DS_Store
|
||||||
|
*.hi
|
||||||
|
*.o
|
||||||
|
tmp/
|
||||||
|
build/
|
||||||
|
.ghc.environment.*
|
||||||
55
README.md
Normal file
55
README.md
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
|
||||||
|
Circom witness generators
|
||||||
|
-------------------------
|
||||||
|
|
||||||
|
This piece of software takes the computation graph files generated by
|
||||||
|
[`circom-witnesscalc`](https://github.com/iden3/circom-witnesscalc),
|
||||||
|
and either interprets or compiles them to various algebra backends.
|
||||||
|
|
||||||
|
### Implementation status
|
||||||
|
|
||||||
|
Compiler (in Haskell):
|
||||||
|
|
||||||
|
- [x] parsing the graph files
|
||||||
|
- [x] naive interpreter
|
||||||
|
- [ ] constantine backend
|
||||||
|
- [ ] zikkurat backend
|
||||||
|
- [ ] arkworks backend
|
||||||
|
|
||||||
|
Nim witness generator:
|
||||||
|
|
||||||
|
- [x] parsing the graph files
|
||||||
|
- [ ] generating the witness
|
||||||
|
- [ ] proper error handling
|
||||||
|
|
||||||
|
### Testing & correctness
|
||||||
|
|
||||||
|
I haven't yet done any proper testing, apart from "works for our purposes".
|
||||||
|
|
||||||
|
### Circuit optimizations
|
||||||
|
|
||||||
|
NOTE: you _have to_ run `circom` with the `--O2` options, otherwise the
|
||||||
|
witness format will be most probably incompatible with the the one generated
|
||||||
|
by `circom-witnesscalc`.
|
||||||
|
|
||||||
|
### Graph file format
|
||||||
|
|
||||||
|
`circom-witnesscalc` produces binary files encoding a computation graph.
|
||||||
|
|
||||||
|
This has the following format:
|
||||||
|
|
||||||
|
- magic header: "wtns.graph.001" (14 bytes)
|
||||||
|
- number of nodes (8 bytes little endian)
|
||||||
|
- list of nodes, in protobuf serialized format. Each prefixed by a `varint` length
|
||||||
|
- `GraphMetaData`, in protobuf
|
||||||
|
- 8 bytes offset (yes, _at the very end_...), pointing to the start of `GraphMetaData`
|
||||||
|
|
||||||
|
Node format:
|
||||||
|
|
||||||
|
- varint length prefix
|
||||||
|
- protobuf tag-byte (lower 3 bits are `0x02`, upper bits are node type 1..5)
|
||||||
|
- varint length
|
||||||
|
- several record fields
|
||||||
|
- protobuf tag-byte (lower 3 bits are 0x00, upper bits are field index 1,2,3,4)
|
||||||
|
- value is varint word32, except for ConstantNode when it's a little-endian bytes
|
||||||
|
(wrapped a few times, because protobuf is being protobuf...)
|
||||||
191
src/BN254.hs
Normal file
191
src/BN254.hs
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
|
||||||
|
-- | The BN254 scalar field
|
||||||
|
|
||||||
|
{-# LANGUAGE Strict, BangPatterns #-}
|
||||||
|
module BN254 where
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import Prelude hiding (div)
|
||||||
|
import qualified Prelude
|
||||||
|
|
||||||
|
import Data.Bits
|
||||||
|
import Data.Word
|
||||||
|
import Data.Ratio
|
||||||
|
|
||||||
|
import System.Random
|
||||||
|
import Text.Printf
|
||||||
|
|
||||||
|
import Misc
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
fieldPrime :: Integer
|
||||||
|
fieldPrime = 21888242871839275222246405745257275088548364400416034343698204186575808495617
|
||||||
|
|
||||||
|
modP :: Integer -> Integer
|
||||||
|
modP x = mod x fieldPrime
|
||||||
|
|
||||||
|
halfPrimePlus1 :: Integer
|
||||||
|
halfPrimePlus1 = 1 + Prelude.div fieldPrime 2
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
newtype F
|
||||||
|
= MkF Integer
|
||||||
|
deriving (Eq,Show)
|
||||||
|
|
||||||
|
fromF :: F -> Integer
|
||||||
|
fromF (MkF x) = x
|
||||||
|
|
||||||
|
-- from the circom docs: @val(z) = z-p if p/2 +1 <= z < p@
|
||||||
|
signedFromF :: F -> Integer
|
||||||
|
signedFromF (MkF x) = if x >= halfPrimePlus1 then x - fieldPrime else x
|
||||||
|
|
||||||
|
toF :: Integer -> F
|
||||||
|
toF = MkF . modP
|
||||||
|
|
||||||
|
isZero :: F -> Bool
|
||||||
|
isZero (MkF x) = (x == 0)
|
||||||
|
|
||||||
|
fromBool :: Bool -> F
|
||||||
|
fromBool b = if b then 1 else 0
|
||||||
|
|
||||||
|
toBool :: F -> Bool
|
||||||
|
toBool = not . isZero
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
neg :: F -> F
|
||||||
|
neg (MkF x) = toF (negate x)
|
||||||
|
|
||||||
|
add :: F -> F -> F
|
||||||
|
add (MkF x) (MkF y) = toF (x+y)
|
||||||
|
|
||||||
|
sub :: F -> F -> F
|
||||||
|
sub (MkF x) (MkF y) = toF (x-y)
|
||||||
|
|
||||||
|
mul :: F -> F -> F
|
||||||
|
mul (MkF x) (MkF y) = toF (x*y)
|
||||||
|
|
||||||
|
instance Num F where
|
||||||
|
fromInteger = toF
|
||||||
|
negate = neg
|
||||||
|
(+) = add
|
||||||
|
(-) = sub
|
||||||
|
(*) = mul
|
||||||
|
abs = id
|
||||||
|
signum _ = toF 1
|
||||||
|
|
||||||
|
square :: F -> F
|
||||||
|
square x = x*x
|
||||||
|
|
||||||
|
rndF :: IO F
|
||||||
|
rndF = MkF <$> randomRIO (0,fieldPrime-1)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
pow :: F -> Integer -> F
|
||||||
|
pow x0 exponent
|
||||||
|
| exponent < 0 = error "power: expecting positive exponent"
|
||||||
|
| otherwise = go 1 x0 exponent
|
||||||
|
where
|
||||||
|
go !acc _ 0 = acc
|
||||||
|
go !acc s e = go acc' s' (shiftR e 1) where
|
||||||
|
s' = s*s
|
||||||
|
acc' = if e .&. 1 == 0 then acc else acc*s
|
||||||
|
|
||||||
|
invNaive :: F -> F
|
||||||
|
invNaive x = pow x (fieldPrime - 2)
|
||||||
|
|
||||||
|
divNaive :: F -> F -> F
|
||||||
|
divNaive x y = x * invNaive y
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
instance Fractional F where
|
||||||
|
fromRational q = fromInteger (numerator q) / fromInteger (denominator q)
|
||||||
|
recip = inv
|
||||||
|
(/) = div
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
fromBytesLE :: [Word8] -> F
|
||||||
|
fromBytesLE = toF . integerFromBytesLE
|
||||||
|
|
||||||
|
integerFromBytesLE :: [Word8] -> Integer
|
||||||
|
integerFromBytesLE = go where
|
||||||
|
go [] = 0
|
||||||
|
go (b:bs) = fromIntegral b + (shiftL (go bs) 8)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
instance ShowHex Integer where showHex = printf "0x%x"
|
||||||
|
instance ShowHex F where showHex (MkF x) = showHex x
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
-- | Inversion (using Euclid's algorithm)
|
||||||
|
inv :: F -> F
|
||||||
|
inv (MkF a)
|
||||||
|
| a == 0 = 0 -- error "field inverse of zero (generic prime)"
|
||||||
|
| otherwise = MkF (euclid 1 0 a fieldPrime)
|
||||||
|
|
||||||
|
-- | Division via Euclid's algorithm
|
||||||
|
div :: F -> F -> F
|
||||||
|
div (MkF a) (MkF b)
|
||||||
|
| b == 0 = 0 -- error "field division by zero (generic prime)"
|
||||||
|
| otherwise = MkF (euclid a 0 b fieldPrime)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- * Euclidean algorithm
|
||||||
|
|
||||||
|
-- | Extended binary Euclidean algorithm
|
||||||
|
euclid :: Integer -> Integer -> Integer -> Integer -> Integer
|
||||||
|
euclid !x1 !x2 !u !v = go x1 x2 u v where
|
||||||
|
|
||||||
|
p = fieldPrime
|
||||||
|
|
||||||
|
halfp1 = shiftR (p+1) 1
|
||||||
|
|
||||||
|
modp :: Integer -> Integer
|
||||||
|
modp n = mod n p
|
||||||
|
|
||||||
|
-- Inverse using the binary Euclidean algorithm
|
||||||
|
euclid :: Integer -> Integer
|
||||||
|
euclid a
|
||||||
|
| a == 0 = 0
|
||||||
|
| otherwise = go 1 0 a p
|
||||||
|
|
||||||
|
go :: Integer -> Integer -> Integer -> Integer -> Integer
|
||||||
|
go !x1 !x2 !u !v
|
||||||
|
| u==1 = x1
|
||||||
|
| v==1 = x2
|
||||||
|
| otherwise = stepU x1 x2 u v
|
||||||
|
|
||||||
|
stepU :: Integer -> Integer -> Integer -> Integer -> Integer
|
||||||
|
stepU !x1 !x2 !u !v = if even u
|
||||||
|
then let u' = shiftR u 1
|
||||||
|
x1' = if even x1 then shiftR x1 1 else shiftR x1 1 + halfp1
|
||||||
|
in stepU x1' x2 u' v
|
||||||
|
else stepV x1 x2 u v
|
||||||
|
|
||||||
|
stepV :: Integer -> Integer -> Integer -> Integer -> Integer
|
||||||
|
stepV !x1 !x2 !u !v = if even v
|
||||||
|
then let v' = shiftR v 1
|
||||||
|
x2' = if even x2 then shiftR x2 1 else shiftR x2 1 + halfp1
|
||||||
|
in stepV x1 x2' u v'
|
||||||
|
else final x1 x2 u v
|
||||||
|
|
||||||
|
final :: Integer -> Integer -> Integer -> Integer -> Integer
|
||||||
|
final !x1 !x2 !u !v = if u>=v
|
||||||
|
|
||||||
|
then let u' = u-v
|
||||||
|
x1' = if x1 >= x2 then modp (x1-x2) else modp (x1+p-x2)
|
||||||
|
in go x1' x2 u' v
|
||||||
|
|
||||||
|
else let v' = v-u
|
||||||
|
x2' = if x2 >= x1 then modp (x2-x1) else modp (x2+p-x1)
|
||||||
|
in go x1 x2' u v'
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
104
src/Graph.hs
Normal file
104
src/Graph.hs
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
|
||||||
|
{-# LANGUAGE StrictData #-}
|
||||||
|
module Graph where
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import Text.Printf
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import Data.Word
|
||||||
|
|
||||||
|
data Graph = Graph
|
||||||
|
{ graphNodes :: [Node]
|
||||||
|
, graphMeta :: GraphMetaData
|
||||||
|
}
|
||||||
|
deriving Show
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
-- | Unary operations
|
||||||
|
data UnoOp
|
||||||
|
= Neg -- ^ @= 0@
|
||||||
|
| Id -- ^ @= 1@
|
||||||
|
deriving (Eq,Enum,Bounded,Show)
|
||||||
|
|
||||||
|
data DuoOp
|
||||||
|
= Mul -- ^ @= 0@
|
||||||
|
| Div -- ^ @= 1@
|
||||||
|
| Add -- ^ @= 2@
|
||||||
|
| Sub -- ^ @= 3@
|
||||||
|
| Pow -- ^ @= 4@
|
||||||
|
| Idiv -- ^ @= 5@
|
||||||
|
| Mod -- ^ @= 6@
|
||||||
|
| Eq_ -- ^ @= 7@
|
||||||
|
| Neq -- ^ @= 8@
|
||||||
|
| Lt -- ^ @= 9@
|
||||||
|
| Gt -- ^ @= 10@
|
||||||
|
| Leq -- ^ @= 11@
|
||||||
|
| Geq -- ^ @= 12@
|
||||||
|
| Land -- ^ @= 13@
|
||||||
|
| Lor -- ^ @= 14@
|
||||||
|
| Shl -- ^ @= 15@
|
||||||
|
| Shr -- ^ @= 16@
|
||||||
|
| Bor -- ^ @= 17@
|
||||||
|
| Band -- ^ @= 18@
|
||||||
|
| Bxor -- ^ @= 19@
|
||||||
|
deriving (Eq,Enum,Bounded,Show)
|
||||||
|
|
||||||
|
data TresOp
|
||||||
|
= TernCond -- ^ @= 0@
|
||||||
|
deriving (Eq,Enum,Bounded,Show)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
newtype BigUInt
|
||||||
|
= BigUInt [Word8] -- ^ little endian
|
||||||
|
|
||||||
|
showBigUInt :: BigUInt -> String
|
||||||
|
showBigUInt (BigUInt bytes) = "0x" ++ concatMap f (reverse bytes) where
|
||||||
|
f :: Word8 -> String
|
||||||
|
f = printf "%02x"
|
||||||
|
|
||||||
|
instance Show BigUInt where show = showBigUInt
|
||||||
|
|
||||||
|
newtype InputNode
|
||||||
|
= InputNode Word32
|
||||||
|
deriving (Show)
|
||||||
|
|
||||||
|
newtype ConstantNode
|
||||||
|
= ConstantNode BigUInt
|
||||||
|
deriving (Show)
|
||||||
|
|
||||||
|
data UnoOpNode = UnoOpNode !UnoOp !Word32 deriving (Show)
|
||||||
|
data DuoOpNode = DuoOpNode !DuoOp !Word32 !Word32 deriving (Show)
|
||||||
|
data TresOpNode = TresOpNode !TresOp !Word32 !Word32 !Word32 deriving (Show)
|
||||||
|
|
||||||
|
data Node
|
||||||
|
= AnInputNode InputNode -- @= 1@
|
||||||
|
| AConstantNode ConstantNode -- @= 2@
|
||||||
|
| AnUnoOpNode UnoOpNode -- @= 3@
|
||||||
|
| ADuoOpNode DuoOpNode -- @= 4@
|
||||||
|
| ATresOpNode TresOpNode -- @= 5@
|
||||||
|
deriving (Show)
|
||||||
|
|
||||||
|
data SignalDescription = SignalDescription
|
||||||
|
{ signalOffset :: !Word32
|
||||||
|
, signalLength :: !Word32
|
||||||
|
}
|
||||||
|
deriving (Show)
|
||||||
|
|
||||||
|
newtype WitnessMapping
|
||||||
|
= WitnessMapping { fromWitnessMapping :: [Word32] }
|
||||||
|
deriving (Show)
|
||||||
|
|
||||||
|
type CircuitInputs = [(String, SignalDescription)]
|
||||||
|
|
||||||
|
data GraphMetaData = GraphMetaData
|
||||||
|
{ witnessMapping :: WitnessMapping
|
||||||
|
, inputSignals :: CircuitInputs
|
||||||
|
}
|
||||||
|
deriving (Show)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
34
src/Main.hs
Normal file
34
src/Main.hs
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
|
||||||
|
module Main where
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import Data.Map (Map)
|
||||||
|
import qualified Data.Map as Map
|
||||||
|
|
||||||
|
import Witness
|
||||||
|
import Parser
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
(~>) :: String -> a -> (String, a)
|
||||||
|
(~>) = (,)
|
||||||
|
|
||||||
|
infix 2 ~>
|
||||||
|
|
||||||
|
testInputs :: Map String [Integer]
|
||||||
|
testInputs = Map.fromList
|
||||||
|
[ "a" ~> [0xff01]
|
||||||
|
, "b" ~> [0xff02]
|
||||||
|
]
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = do
|
||||||
|
Right graph <- parseGraphFile "../tmp/graph2.bin"
|
||||||
|
putStrLn ""
|
||||||
|
print graph
|
||||||
|
let wtns = witnessCalc testInputs graph
|
||||||
|
putStrLn ""
|
||||||
|
print wtns
|
||||||
33
src/Misc.hs
Normal file
33
src/Misc.hs
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
|
||||||
|
{-# LANGUAGE Strict #-}
|
||||||
|
module Misc where
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import Data.Bits
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ShowHex a where
|
||||||
|
showHex :: a -> String
|
||||||
|
|
||||||
|
printHex :: ShowHex a => a -> IO ()
|
||||||
|
printHex x = putStrLn (showHex x)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Integer logarithm
|
||||||
|
|
||||||
|
-- | Largest integer @k@ such that @2^k@ is smaller or equal to @n@
|
||||||
|
integerLog2 :: Integer -> Int
|
||||||
|
integerLog2 n = go n where
|
||||||
|
go 0 = -1
|
||||||
|
go k = 1 + go (shiftR k 1)
|
||||||
|
|
||||||
|
-- | Smallest integer @k@ such that @2^k@ is larger or equal to @n@
|
||||||
|
ceilingLog2 :: Integer -> Int
|
||||||
|
ceilingLog2 0 = 0
|
||||||
|
ceilingLog2 n = 1 + go (n-1) where
|
||||||
|
go 0 = -1
|
||||||
|
go k = 1 + go (shiftR k 1)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
345
src/Parser.hs
Normal file
345
src/Parser.hs
Normal file
@ -0,0 +1,345 @@
|
|||||||
|
|
||||||
|
-- | Parsing the graph binary format
|
||||||
|
|
||||||
|
{-# LANGUAGE Strict, PackageImports, BangPatterns #-}
|
||||||
|
module Parser where
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import Data.Bits
|
||||||
|
import Data.Word
|
||||||
|
import Data.List
|
||||||
|
import Data.Ord
|
||||||
|
|
||||||
|
import Control.Monad
|
||||||
|
import Control.Applicative
|
||||||
|
import System.IO
|
||||||
|
|
||||||
|
import Data.ByteString.Lazy (ByteString)
|
||||||
|
import qualified Data.ByteString.Lazy as L
|
||||||
|
import qualified Data.ByteString.Lazy.Char8 as LC
|
||||||
|
|
||||||
|
import "binary" Data.Binary.Get
|
||||||
|
import "binary" Data.Binary.Builder as Builder
|
||||||
|
|
||||||
|
import Graph
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
{-
|
||||||
|
test :: IO ()
|
||||||
|
test = do
|
||||||
|
Right graph <- parseGraphFile "../tmp/graph.bin"
|
||||||
|
print graph
|
||||||
|
-}
|
||||||
|
|
||||||
|
{-
|
||||||
|
nodeEx1 = 0x06 : nodeEx1' :: [Word8]
|
||||||
|
nodeEx2 = 0x08 : nodeEx2' :: [Word8]
|
||||||
|
|
||||||
|
nodeEx1' = 0x22 : 0x04 : duoNodeEx1 :: [Word8]
|
||||||
|
nodeEx2' = 0x22 : 0x06 : duoNodeEx2 :: [Word8]
|
||||||
|
|
||||||
|
duoNodeEx1 = [ 0x10 , 0x05 , 0x18 , 0x05 ] :: [Word8]
|
||||||
|
duoNodeEx2 = [ 0x08 , 0x02 , 0x10 , 0x03 , 0x18 , 0x04 ] :: [Word8]
|
||||||
|
-}
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type Msg = String
|
||||||
|
|
||||||
|
parseGraphFile :: FilePath -> IO (Either Msg Graph)
|
||||||
|
parseGraphFile fname = do
|
||||||
|
h <- openBinaryFile fname ReadMode
|
||||||
|
ei <- readGraphFile h
|
||||||
|
hClose h
|
||||||
|
return ei
|
||||||
|
|
||||||
|
hGetBytes :: Handle -> Int -> IO ByteString
|
||||||
|
hGetBytes h n = L.hGet h (fromIntegral n)
|
||||||
|
|
||||||
|
hSeekInt :: Handle -> Int -> IO ()
|
||||||
|
hSeekInt h ofs = hSeek h AbsoluteSeek (fromIntegral ofs)
|
||||||
|
|
||||||
|
readGraphFile :: Handle -> IO (Either Msg Graph)
|
||||||
|
readGraphFile h = do
|
||||||
|
flen <- (fromIntegral :: Integer -> Int) <$> hFileSize h
|
||||||
|
magic <- hGetBytes h (length magicHeader)
|
||||||
|
if magic /= LC.pack magicHeader
|
||||||
|
then return $ Left "magic header not found or invalid"
|
||||||
|
else do
|
||||||
|
hSeekInt h (flen - 8)
|
||||||
|
offset <- (fromIntegral . runGet getWord64le) <$> hGetBytes h 8
|
||||||
|
putStrLn $ "metadata offset = " ++ show offset
|
||||||
|
if (offset >= flen) || (offset <= 18)
|
||||||
|
then return $ Left "invalid final `graphMetaData` offset bytes"
|
||||||
|
else do
|
||||||
|
hSeekInt h (length magicHeader)
|
||||||
|
part1 <- hGetBytes h (offset - length magicHeader)
|
||||||
|
part2 <- hGetBytes h (flen - offset - 8)
|
||||||
|
return (Right $ Graph (parseNodes part1) (parseMeta part2))
|
||||||
|
|
||||||
|
magicHeader :: String
|
||||||
|
magicHeader = "wtns.graph.001"
|
||||||
|
|
||||||
|
parseNodes :: ByteString -> [Node]
|
||||||
|
parseNodes = runGet getNodes
|
||||||
|
|
||||||
|
parseMeta :: ByteString -> GraphMetaData
|
||||||
|
parseMeta = runGet getMetaData
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
varInt' :: Get Word64
|
||||||
|
varInt' = go 0 where
|
||||||
|
go !cnt = if cnt >= 8 then return 0 else do
|
||||||
|
w <- getWord8
|
||||||
|
if (w < 128)
|
||||||
|
then return (fromIntegral w)
|
||||||
|
else do
|
||||||
|
let x = fromIntegral (w .&. 127)
|
||||||
|
y <- go (cnt+1)
|
||||||
|
return (x + 128*y)
|
||||||
|
|
||||||
|
varInt :: Get Int
|
||||||
|
varInt = fromIntegral <$> varInt'
|
||||||
|
|
||||||
|
varUInt :: Get Word32
|
||||||
|
varUInt = fromIntegral <$> varInt'
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
getNodes :: Get [Node]
|
||||||
|
getNodes = do
|
||||||
|
n <- getWord64le
|
||||||
|
replicateM (fromIntegral n) getNode
|
||||||
|
|
||||||
|
-- | with varint length prefix
|
||||||
|
getNode :: Get Node
|
||||||
|
getNode = do
|
||||||
|
len <- varInt
|
||||||
|
bs <- getLazyByteString (fromIntegral len)
|
||||||
|
return (runGet getNode' bs)
|
||||||
|
|
||||||
|
-- | without varint length prefix
|
||||||
|
getNode' :: Get Node
|
||||||
|
getNode' = do
|
||||||
|
nodetype <- getFieldId LEN
|
||||||
|
case nodetype of
|
||||||
|
1 -> AnInputNode <$> getInputNode
|
||||||
|
2 -> AConstantNode <$> getConstantNode
|
||||||
|
3 -> AnUnoOpNode <$> getUnoOpNode
|
||||||
|
4 -> ADuoOpNode <$> getDuoOpNode
|
||||||
|
5 -> ATresOpNode <$> getTresOpNode
|
||||||
|
_ -> error "unexpected node type"
|
||||||
|
|
||||||
|
getInputNode :: Get InputNode
|
||||||
|
getInputNode = do
|
||||||
|
SomeNode idx _ _ _ <- getSomeNode
|
||||||
|
return (InputNode idx)
|
||||||
|
|
||||||
|
getConstantNode :: Get ConstantNode
|
||||||
|
getConstantNode = do
|
||||||
|
len <- varInt
|
||||||
|
bs <- getLazyByteString (fromIntegral len)
|
||||||
|
return $ ConstantNode (runGet getBigUInt bs)
|
||||||
|
|
||||||
|
getBigUInt :: Get BigUInt
|
||||||
|
getBigUInt = do
|
||||||
|
fld <- getFieldId LEN
|
||||||
|
if fld /= 1
|
||||||
|
then error "getBigUInt"
|
||||||
|
else do
|
||||||
|
len <- varInt
|
||||||
|
bs <- getLazyByteString (fromIntegral len)
|
||||||
|
return $ BigUInt (runGet getByteList bs)
|
||||||
|
|
||||||
|
getByteList :: Get [Word8]
|
||||||
|
getByteList = do
|
||||||
|
fld <- getFieldId LEN
|
||||||
|
if fld /= 1
|
||||||
|
then error "getByteList"
|
||||||
|
else do
|
||||||
|
len <- varInt
|
||||||
|
bs <- getLazyByteString (fromIntegral len)
|
||||||
|
return (L.unpack bs)
|
||||||
|
|
||||||
|
getUnoOpNode :: Get UnoOpNode
|
||||||
|
getUnoOpNode = do
|
||||||
|
SomeNode op arg1 _ _ <- getSomeNode
|
||||||
|
return (UnoOpNode (wordToEnum op) arg1)
|
||||||
|
|
||||||
|
getDuoOpNode :: Get DuoOpNode
|
||||||
|
getDuoOpNode = do
|
||||||
|
SomeNode op arg1 arg2 _ <- getSomeNode
|
||||||
|
return (DuoOpNode (wordToEnum op) arg1 arg2)
|
||||||
|
|
||||||
|
getTresOpNode :: Get TresOpNode
|
||||||
|
getTresOpNode = do
|
||||||
|
SomeNode op arg1 arg2 arg3 <- getSomeNode
|
||||||
|
return (TresOpNode (wordToEnum op) arg1 arg2 arg3)
|
||||||
|
|
||||||
|
wordToEnum :: Enum a => Word32 -> a
|
||||||
|
wordToEnum w = toEnum (fromIntegral w)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
data SomeNode = SomeNode
|
||||||
|
{ field1 :: Word32
|
||||||
|
, field2 :: Word32
|
||||||
|
, field3 :: Word32
|
||||||
|
, field4 :: Word32
|
||||||
|
}
|
||||||
|
deriving Show
|
||||||
|
|
||||||
|
defaultSomeNode :: SomeNode
|
||||||
|
defaultSomeNode = SomeNode 0 0 0 0
|
||||||
|
|
||||||
|
insert1 :: (Int,Word32) -> SomeNode -> SomeNode
|
||||||
|
insert1 (idx,val) old = case idx of
|
||||||
|
1 -> old { field1 = val }
|
||||||
|
2 -> old { field2 = val }
|
||||||
|
3 -> old { field3 = val }
|
||||||
|
4 -> old { field4 = val }
|
||||||
|
|
||||||
|
insertMany :: [(Int,Word32)] -> SomeNode -> SomeNode
|
||||||
|
insertMany list old = foldl' (flip insert1) old list
|
||||||
|
|
||||||
|
getSomeNode :: Get SomeNode
|
||||||
|
getSomeNode = do
|
||||||
|
len <- varInt
|
||||||
|
bs <- getLazyByteString (fromIntegral len)
|
||||||
|
let list = runGet getRecord bs
|
||||||
|
return $ insertMany list defaultSomeNode
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- TODO: refactor this mess
|
||||||
|
|
||||||
|
getMetaData :: Get GraphMetaData
|
||||||
|
getMetaData = do
|
||||||
|
len <- varInt
|
||||||
|
mapping <- getWitnessMapping
|
||||||
|
inputs <- getCircuitInputs
|
||||||
|
return $ GraphMetaData mapping inputs
|
||||||
|
|
||||||
|
getWitnessMapping :: Get WitnessMapping
|
||||||
|
getWitnessMapping = do
|
||||||
|
fld <- getFieldId LEN
|
||||||
|
if fld /= 1
|
||||||
|
then error "getWitnessMapping: expecting field 1"
|
||||||
|
else do
|
||||||
|
len <- varInt
|
||||||
|
bs <- getLazyByteString (fromIntegral len)
|
||||||
|
return $ WitnessMapping (runGet worker bs)
|
||||||
|
where
|
||||||
|
worker :: Get [Word32]
|
||||||
|
worker = isEmpty >>= \b -> if b
|
||||||
|
then return []
|
||||||
|
else (:) <$> varUInt <*> worker
|
||||||
|
|
||||||
|
getCircuitInputs :: Get CircuitInputs
|
||||||
|
getCircuitInputs = worker where
|
||||||
|
|
||||||
|
worker :: Get [(String, SignalDescription)]
|
||||||
|
worker = isEmpty >>= \b -> if b
|
||||||
|
then return []
|
||||||
|
else (:) <$> getSingleInput <*> worker
|
||||||
|
|
||||||
|
getSingleInput :: Get (String, SignalDescription)
|
||||||
|
getSingleInput = do
|
||||||
|
fld <- getFieldId LEN
|
||||||
|
if fld /= 2
|
||||||
|
then error "getCircuitInputs: expecting field 2"
|
||||||
|
else do
|
||||||
|
len <- varInt
|
||||||
|
bs <- getLazyByteString (fromIntegral len)
|
||||||
|
return $ runGet inputHelper bs
|
||||||
|
|
||||||
|
inputHelper = do
|
||||||
|
name <- getName
|
||||||
|
signal <- getSignal
|
||||||
|
return (name,signal)
|
||||||
|
|
||||||
|
getName :: Get String
|
||||||
|
getName = do
|
||||||
|
fld <- getFieldId LEN
|
||||||
|
if fld /= 1
|
||||||
|
then error "getCircuitInputs/getName: expecting field 1"
|
||||||
|
else do
|
||||||
|
len <- varInt
|
||||||
|
bs <- getLazyByteString (fromIntegral len)
|
||||||
|
return (LC.unpack bs)
|
||||||
|
|
||||||
|
getSignal :: Get SignalDescription
|
||||||
|
getSignal = do
|
||||||
|
fld <- getFieldId LEN
|
||||||
|
if fld /= 2
|
||||||
|
then error "getCircuitInputs/getSignal: expecting field 2"
|
||||||
|
else do
|
||||||
|
len <- varInt
|
||||||
|
bs <- getLazyByteString (fromIntegral len)
|
||||||
|
return $ runGet signalHelper bs
|
||||||
|
|
||||||
|
signalHelper = do
|
||||||
|
ofs <- getSignalOffset
|
||||||
|
len <- getSignalLength
|
||||||
|
return $ SignalDescription { signalOffset = ofs , signalLength = len }
|
||||||
|
|
||||||
|
getSignalOffset = do
|
||||||
|
fld <- getFieldId VARINT
|
||||||
|
if fld /= 1
|
||||||
|
then error "getCircuitInputs/getSignalOffset: expecting field 1"
|
||||||
|
else varUInt
|
||||||
|
|
||||||
|
getSignalLength = do
|
||||||
|
fld <- getFieldId VARINT
|
||||||
|
if fld /= 2
|
||||||
|
then error "getCircuitInputs/getSignalLength: expecting field 2"
|
||||||
|
else varUInt
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- * protobuf stuff
|
||||||
|
|
||||||
|
-- | There are six wire types: VARINT, I64, LEN, SGROUP, EGROUP, and I32
|
||||||
|
data WireType
|
||||||
|
= VARINT -- ^ used for: int32, int64, uint32, uint64, sint32, sint64, bool, enum
|
||||||
|
| I64 -- ^ used for: fixed64, sfixed64, double
|
||||||
|
| LEN -- ^ used for: string, bytes, embedded messages, packed repeated fields
|
||||||
|
| SGROUP -- ^ used for: group start (deprecated)
|
||||||
|
| EGROUP -- ^ used for: group end (deprecated)
|
||||||
|
| I32 -- ^ fixed32, sfixed32, float
|
||||||
|
deriving (Eq,Show,Enum,Bounded)
|
||||||
|
|
||||||
|
type FieldId = Int
|
||||||
|
|
||||||
|
getFieldId :: WireType -> Get FieldId
|
||||||
|
getFieldId wty = do
|
||||||
|
tag <- getWord8
|
||||||
|
let (fld,wty') = decodeTag tag
|
||||||
|
if wty == wty'
|
||||||
|
then return fld
|
||||||
|
else error "getFieldId: unexpected protobuf wire type"
|
||||||
|
|
||||||
|
decodeTag_ :: Word8 -> FieldId
|
||||||
|
decodeTag_ = fst . decodeTag
|
||||||
|
|
||||||
|
decodeTag :: Word8 -> (FieldId, WireType)
|
||||||
|
decodeTag w = (fld , wty) where
|
||||||
|
fld = fromIntegral (shiftR w 3)
|
||||||
|
wty = toEnum (fromIntegral (w .&. 7))
|
||||||
|
|
||||||
|
-- (index, value) pair
|
||||||
|
getEntry :: Get (Int,Word32)
|
||||||
|
getEntry = do
|
||||||
|
idx <- getFieldId VARINT
|
||||||
|
val <- varUInt
|
||||||
|
return (idx,val)
|
||||||
|
|
||||||
|
-- list of (index, value) pairs
|
||||||
|
getRecord :: Get [(Int,Word32)]
|
||||||
|
getRecord = sort <$> go where
|
||||||
|
go = isEmpty >>= \b -> if b
|
||||||
|
then return []
|
||||||
|
else (:) <$> getEntry <*> go
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
101
src/Semantics.hs
Normal file
101
src/Semantics.hs
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
|
||||||
|
-- | See <https://docs.circom.io/circom-language/basic-operators> for the official
|
||||||
|
-- semantics of the operations
|
||||||
|
|
||||||
|
{-# LANGUAGE StrictData, DeriveFunctor #-}
|
||||||
|
module Semantics where
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import Data.Bits
|
||||||
|
|
||||||
|
import BN254
|
||||||
|
import Misc
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
data PrimOp a
|
||||||
|
-- unary
|
||||||
|
= Neg a
|
||||||
|
| Id a
|
||||||
|
-- binary
|
||||||
|
| Mul a a
|
||||||
|
| Div a a
|
||||||
|
| Add a a
|
||||||
|
| Sub a a
|
||||||
|
| Pow a a
|
||||||
|
| Idiv a a
|
||||||
|
| Mod a a
|
||||||
|
| Eq_ a a
|
||||||
|
| Neq a a
|
||||||
|
| Lt a a
|
||||||
|
| Gt a a
|
||||||
|
| Leq a a
|
||||||
|
| Geq a a
|
||||||
|
| Land a a
|
||||||
|
| Lor a a
|
||||||
|
| Shl a a
|
||||||
|
| Shr a a
|
||||||
|
| Bor a a
|
||||||
|
| Band a a
|
||||||
|
| Bxor a a
|
||||||
|
-- ternary
|
||||||
|
| Cond a a a
|
||||||
|
deriving (Show,Functor)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
evalPrimOp :: PrimOp F -> F
|
||||||
|
evalPrimOp prim = case prim of
|
||||||
|
Neg x -> neg x
|
||||||
|
Id x -> x
|
||||||
|
Mul x y -> mul x y
|
||||||
|
Div x y -> BN254.div x y
|
||||||
|
Add x y -> add x y
|
||||||
|
Sub x y -> sub x y
|
||||||
|
Pow x y -> pow x (fromF y)
|
||||||
|
Idiv x y -> if isZero y then 0 else toF $ Prelude.div (fromF x) (fromF y)
|
||||||
|
Mod x y -> if isZero y then 0 else toF $ Prelude.mod (fromF x) (fromF y)
|
||||||
|
Eq_ x y -> fromBool (x == y)
|
||||||
|
Neq x y -> fromBool (x /= y)
|
||||||
|
Lt x y -> fromBool (signedFromF x < signedFromF y)
|
||||||
|
Gt x y -> fromBool (signedFromF x > signedFromF y)
|
||||||
|
Leq x y -> fromBool (signedFromF x <= signedFromF y)
|
||||||
|
Geq x y -> fromBool (signedFromF x >= signedFromF y)
|
||||||
|
Land x y -> fromBool (toBool x && toBool y)
|
||||||
|
Lor x y -> fromBool (toBool x || toBool y)
|
||||||
|
Shl x y -> shiftLeft x (fromF y)
|
||||||
|
Shr x y -> shiftRight x (fromF y)
|
||||||
|
Bor x y -> toF (fromF x .|. fromF y)
|
||||||
|
Band x y -> toF (fromF x .&. fromF y)
|
||||||
|
Bxor x y -> toF (fromF x `xor` fromF y)
|
||||||
|
Cond b x y -> if toBool b then x else y
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
fieldMask :: Integer
|
||||||
|
fieldMask = shiftL 1 fieldBits - 1
|
||||||
|
|
||||||
|
fieldBits :: Int
|
||||||
|
fieldBits = ceilingLog2 fieldPrime
|
||||||
|
|
||||||
|
shiftRight :: F -> Integer -> F
|
||||||
|
shiftRight (MkF x) k = if k < halfPrimePlus1
|
||||||
|
then if k > fromIntegral fieldBits
|
||||||
|
then 0
|
||||||
|
else MkF (shiftR x (fromInteger k))
|
||||||
|
else shiftLeft (MkF x) (fieldPrime - k)
|
||||||
|
|
||||||
|
shiftLeft :: F -> Integer -> F
|
||||||
|
shiftLeft (MkF x) k = if k < halfPrimePlus1
|
||||||
|
then if k > fromIntegral fieldBits
|
||||||
|
then 0
|
||||||
|
else toF (shiftL x (fromIntegral k) .&. fieldMask)
|
||||||
|
else shiftRight (MkF x) (fieldPrime - k)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
{-
|
||||||
|
notYet :: a
|
||||||
|
notYet = error "not yet implemented"
|
||||||
|
-}
|
||||||
104
src/Witness.hs
Normal file
104
src/Witness.hs
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
|
||||||
|
{-# LANGUAGE StrictData #-}
|
||||||
|
module Witness where
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import Data.Array
|
||||||
|
import Data.Word
|
||||||
|
|
||||||
|
import qualified Data.Map as Map ; import Data.Map (Map )
|
||||||
|
import qualified Data.IntMap as IntMap ; import Data.IntMap (IntMap)
|
||||||
|
|
||||||
|
import BN254
|
||||||
|
|
||||||
|
import qualified Semantics as S ; import Semantics ( PrimOp , evalPrimOp )
|
||||||
|
import qualified Graph as G ; import Graph ( Graph(..) , Node(..) , UnoOpNode(..) , DuoOpNode(..) , TresOpNode(..) , SignalDescription(..) )
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type Witness = Array Int F
|
||||||
|
|
||||||
|
type Inputs = Map String [Integer]
|
||||||
|
|
||||||
|
witnessCalc :: Inputs -> Graph -> Witness
|
||||||
|
witnessCalc inputs (Graph nodes meta) = witness where
|
||||||
|
nodesArr = listArray (0,length nodes-1) nodes
|
||||||
|
rawWitness = evaluateNodes rawInputs nodesArr
|
||||||
|
rawInputs = convertInputs (G.inputSignals meta) inputs
|
||||||
|
mapping_ = G.fromWitnessMapping (G.witnessMapping meta)
|
||||||
|
wtnslen = length mapping_
|
||||||
|
mapping = listArray (0,wtnslen-1) mapping_
|
||||||
|
witness = listArray (0,wtnslen-1) $ [ rawWitness!(fromIntegral (mapping!i)) | i<-[0..wtnslen-1] ]
|
||||||
|
|
||||||
|
convertInputs :: [(String,SignalDescription)] -> Map String [Integer] -> IntMap F
|
||||||
|
convertInputs descTable inputTable = IntMap.fromList $ concatMap f descTable where
|
||||||
|
f :: (String,SignalDescription) -> [(Int,F)]
|
||||||
|
f (name,desc) = case Map.lookup name inputTable of
|
||||||
|
Nothing -> error $ "input signal `" ++ name ++ "` not found in the given inputs!"
|
||||||
|
Just values -> if length values /= fromIntegral (signalLength desc)
|
||||||
|
then error $ "input signal `" ++ name ++ "` has incorrect size"
|
||||||
|
else let ofs = fromIntegral (signalOffset desc) :: Int
|
||||||
|
in (0,1) : zip [ofs..] (map toF values)
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type RawInputs = IntMap F
|
||||||
|
|
||||||
|
evaluateNodes :: RawInputs -> Array Int Node -> Array Int F
|
||||||
|
evaluateNodes inputs nodes = witness where
|
||||||
|
|
||||||
|
(a,b) = bounds nodes
|
||||||
|
witness = array (0,b-a) $ [ (i-a, worker i) | i<-[a..b] ]
|
||||||
|
|
||||||
|
lkp :: Word32 -> F
|
||||||
|
lkp i = witness!(fromIntegral i)
|
||||||
|
|
||||||
|
worker i = case nodes!i of
|
||||||
|
AnInputNode (G.InputNode idx) -> getInputValue idx
|
||||||
|
AConstantNode (G.ConstantNode big) -> fromBigUInt big
|
||||||
|
AnUnoOpNode unoOpNode -> evalPrimOp (fmap lkp $ unoToPrimOp unoOpNode )
|
||||||
|
ADuoOpNode duoOpNode -> evalPrimOp (fmap lkp $ duoToPrimOp duoOpNode )
|
||||||
|
ATresOpNode tresOpNode -> evalPrimOp (fmap lkp $ tresToPrimOp tresOpNode)
|
||||||
|
|
||||||
|
getInputValue j = case IntMap.lookup (fromIntegral j) inputs of
|
||||||
|
Just y -> y
|
||||||
|
Nothing -> error ("input value not found at index " ++ show j)
|
||||||
|
|
||||||
|
fromBigUInt (G.BigUInt bs) = fromBytesLE bs
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
unoToPrimOp :: UnoOpNode -> PrimOp Word32
|
||||||
|
unoToPrimOp (UnoOpNode op arg1) = case op of
|
||||||
|
G.Neg -> S.Neg arg1
|
||||||
|
G.Id -> S.Id arg1
|
||||||
|
|
||||||
|
duoToPrimOp :: DuoOpNode -> PrimOp Word32
|
||||||
|
duoToPrimOp (DuoOpNode op arg1 arg2) = case op of
|
||||||
|
G.Mul -> S.Mul arg1 arg2
|
||||||
|
G.Div -> S.Div arg1 arg2
|
||||||
|
G.Add -> S.Add arg1 arg2
|
||||||
|
G.Sub -> S.Sub arg1 arg2
|
||||||
|
G.Pow -> S.Pow arg1 arg2
|
||||||
|
G.Idiv -> S.Idiv arg1 arg2
|
||||||
|
G.Mod -> S.Mod arg1 arg2
|
||||||
|
G.Eq_ -> S.Eq_ arg1 arg2
|
||||||
|
G.Neq -> S.Neq arg1 arg2
|
||||||
|
G.Lt -> S.Lt arg1 arg2
|
||||||
|
G.Gt -> S.Gt arg1 arg2
|
||||||
|
G.Leq -> S.Leq arg1 arg2
|
||||||
|
G.Geq -> S.Geq arg1 arg2
|
||||||
|
G.Land -> S.Land arg1 arg2
|
||||||
|
G.Lor -> S.Lor arg1 arg2
|
||||||
|
G.Shl -> S.Shl arg1 arg2
|
||||||
|
G.Shr -> S.Shr arg1 arg2
|
||||||
|
G.Bor -> S.Bor arg1 arg2
|
||||||
|
G.Band -> S.Band arg1 arg2
|
||||||
|
G.Bxor -> S.Bxor arg1 arg2
|
||||||
|
|
||||||
|
tresToPrimOp :: TresOpNode -> PrimOp Word32
|
||||||
|
tresToPrimOp (TresOpNode op arg1 arg2 arg3) = case op of
|
||||||
|
G.TernCond -> S.Cond arg1 arg2 arg3
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
Loading…
x
Reference in New Issue
Block a user