mirror of
https://github.com/logos-storage/circom-witnessgen.git
synced 2026-01-02 13:03:09 +00:00
initial import the Haskell WIP code (parser + interpreter works on a simple example)
This commit is contained in:
parent
0fc22c2a27
commit
de29a073da
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
.DS_Store
|
||||
*.hi
|
||||
*.o
|
||||
tmp/
|
||||
build/
|
||||
.ghc.environment.*
|
||||
55
README.md
Normal file
55
README.md
Normal file
@ -0,0 +1,55 @@
|
||||
|
||||
Circom witness generators
|
||||
-------------------------
|
||||
|
||||
This piece of software takes the computation graph files generated by
|
||||
[`circom-witnesscalc`](https://github.com/iden3/circom-witnesscalc),
|
||||
and either interprets or compiles them to various algebra backends.
|
||||
|
||||
### Implementation status
|
||||
|
||||
Compiler (in Haskell):
|
||||
|
||||
- [x] parsing the graph files
|
||||
- [x] naive interpreter
|
||||
- [ ] constantine backend
|
||||
- [ ] zikkurat backend
|
||||
- [ ] arkworks backend
|
||||
|
||||
Nim witness generator:
|
||||
|
||||
- [x] parsing the graph files
|
||||
- [ ] generating the witness
|
||||
- [ ] proper error handling
|
||||
|
||||
### Testing & correctness
|
||||
|
||||
I haven't yet done any proper testing, apart from "works for our purposes".
|
||||
|
||||
### Circuit optimizations
|
||||
|
||||
NOTE: you _have to_ run `circom` with the `--O2` options, otherwise the
|
||||
witness format will be most probably incompatible with the the one generated
|
||||
by `circom-witnesscalc`.
|
||||
|
||||
### Graph file format
|
||||
|
||||
`circom-witnesscalc` produces binary files encoding a computation graph.
|
||||
|
||||
This has the following format:
|
||||
|
||||
- magic header: "wtns.graph.001" (14 bytes)
|
||||
- number of nodes (8 bytes little endian)
|
||||
- list of nodes, in protobuf serialized format. Each prefixed by a `varint` length
|
||||
- `GraphMetaData`, in protobuf
|
||||
- 8 bytes offset (yes, _at the very end_...), pointing to the start of `GraphMetaData`
|
||||
|
||||
Node format:
|
||||
|
||||
- varint length prefix
|
||||
- protobuf tag-byte (lower 3 bits are `0x02`, upper bits are node type 1..5)
|
||||
- varint length
|
||||
- several record fields
|
||||
- protobuf tag-byte (lower 3 bits are 0x00, upper bits are field index 1,2,3,4)
|
||||
- value is varint word32, except for ConstantNode when it's a little-endian bytes
|
||||
(wrapped a few times, because protobuf is being protobuf...)
|
||||
191
src/BN254.hs
Normal file
191
src/BN254.hs
Normal file
@ -0,0 +1,191 @@
|
||||
|
||||
-- | The BN254 scalar field
|
||||
|
||||
{-# LANGUAGE Strict, BangPatterns #-}
|
||||
module BN254 where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Prelude hiding (div)
|
||||
import qualified Prelude
|
||||
|
||||
import Data.Bits
|
||||
import Data.Word
|
||||
import Data.Ratio
|
||||
|
||||
import System.Random
|
||||
import Text.Printf
|
||||
|
||||
import Misc
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
fieldPrime :: Integer
|
||||
fieldPrime = 21888242871839275222246405745257275088548364400416034343698204186575808495617
|
||||
|
||||
modP :: Integer -> Integer
|
||||
modP x = mod x fieldPrime
|
||||
|
||||
halfPrimePlus1 :: Integer
|
||||
halfPrimePlus1 = 1 + Prelude.div fieldPrime 2
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
newtype F
|
||||
= MkF Integer
|
||||
deriving (Eq,Show)
|
||||
|
||||
fromF :: F -> Integer
|
||||
fromF (MkF x) = x
|
||||
|
||||
-- from the circom docs: @val(z) = z-p if p/2 +1 <= z < p@
|
||||
signedFromF :: F -> Integer
|
||||
signedFromF (MkF x) = if x >= halfPrimePlus1 then x - fieldPrime else x
|
||||
|
||||
toF :: Integer -> F
|
||||
toF = MkF . modP
|
||||
|
||||
isZero :: F -> Bool
|
||||
isZero (MkF x) = (x == 0)
|
||||
|
||||
fromBool :: Bool -> F
|
||||
fromBool b = if b then 1 else 0
|
||||
|
||||
toBool :: F -> Bool
|
||||
toBool = not . isZero
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
neg :: F -> F
|
||||
neg (MkF x) = toF (negate x)
|
||||
|
||||
add :: F -> F -> F
|
||||
add (MkF x) (MkF y) = toF (x+y)
|
||||
|
||||
sub :: F -> F -> F
|
||||
sub (MkF x) (MkF y) = toF (x-y)
|
||||
|
||||
mul :: F -> F -> F
|
||||
mul (MkF x) (MkF y) = toF (x*y)
|
||||
|
||||
instance Num F where
|
||||
fromInteger = toF
|
||||
negate = neg
|
||||
(+) = add
|
||||
(-) = sub
|
||||
(*) = mul
|
||||
abs = id
|
||||
signum _ = toF 1
|
||||
|
||||
square :: F -> F
|
||||
square x = x*x
|
||||
|
||||
rndF :: IO F
|
||||
rndF = MkF <$> randomRIO (0,fieldPrime-1)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
pow :: F -> Integer -> F
|
||||
pow x0 exponent
|
||||
| exponent < 0 = error "power: expecting positive exponent"
|
||||
| otherwise = go 1 x0 exponent
|
||||
where
|
||||
go !acc _ 0 = acc
|
||||
go !acc s e = go acc' s' (shiftR e 1) where
|
||||
s' = s*s
|
||||
acc' = if e .&. 1 == 0 then acc else acc*s
|
||||
|
||||
invNaive :: F -> F
|
||||
invNaive x = pow x (fieldPrime - 2)
|
||||
|
||||
divNaive :: F -> F -> F
|
||||
divNaive x y = x * invNaive y
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
instance Fractional F where
|
||||
fromRational q = fromInteger (numerator q) / fromInteger (denominator q)
|
||||
recip = inv
|
||||
(/) = div
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
fromBytesLE :: [Word8] -> F
|
||||
fromBytesLE = toF . integerFromBytesLE
|
||||
|
||||
integerFromBytesLE :: [Word8] -> Integer
|
||||
integerFromBytesLE = go where
|
||||
go [] = 0
|
||||
go (b:bs) = fromIntegral b + (shiftL (go bs) 8)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
instance ShowHex Integer where showHex = printf "0x%x"
|
||||
instance ShowHex F where showHex (MkF x) = showHex x
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
-- | Inversion (using Euclid's algorithm)
|
||||
inv :: F -> F
|
||||
inv (MkF a)
|
||||
| a == 0 = 0 -- error "field inverse of zero (generic prime)"
|
||||
| otherwise = MkF (euclid 1 0 a fieldPrime)
|
||||
|
||||
-- | Division via Euclid's algorithm
|
||||
div :: F -> F -> F
|
||||
div (MkF a) (MkF b)
|
||||
| b == 0 = 0 -- error "field division by zero (generic prime)"
|
||||
| otherwise = MkF (euclid a 0 b fieldPrime)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Euclidean algorithm
|
||||
|
||||
-- | Extended binary Euclidean algorithm
|
||||
euclid :: Integer -> Integer -> Integer -> Integer -> Integer
|
||||
euclid !x1 !x2 !u !v = go x1 x2 u v where
|
||||
|
||||
p = fieldPrime
|
||||
|
||||
halfp1 = shiftR (p+1) 1
|
||||
|
||||
modp :: Integer -> Integer
|
||||
modp n = mod n p
|
||||
|
||||
-- Inverse using the binary Euclidean algorithm
|
||||
euclid :: Integer -> Integer
|
||||
euclid a
|
||||
| a == 0 = 0
|
||||
| otherwise = go 1 0 a p
|
||||
|
||||
go :: Integer -> Integer -> Integer -> Integer -> Integer
|
||||
go !x1 !x2 !u !v
|
||||
| u==1 = x1
|
||||
| v==1 = x2
|
||||
| otherwise = stepU x1 x2 u v
|
||||
|
||||
stepU :: Integer -> Integer -> Integer -> Integer -> Integer
|
||||
stepU !x1 !x2 !u !v = if even u
|
||||
then let u' = shiftR u 1
|
||||
x1' = if even x1 then shiftR x1 1 else shiftR x1 1 + halfp1
|
||||
in stepU x1' x2 u' v
|
||||
else stepV x1 x2 u v
|
||||
|
||||
stepV :: Integer -> Integer -> Integer -> Integer -> Integer
|
||||
stepV !x1 !x2 !u !v = if even v
|
||||
then let v' = shiftR v 1
|
||||
x2' = if even x2 then shiftR x2 1 else shiftR x2 1 + halfp1
|
||||
in stepV x1 x2' u v'
|
||||
else final x1 x2 u v
|
||||
|
||||
final :: Integer -> Integer -> Integer -> Integer -> Integer
|
||||
final !x1 !x2 !u !v = if u>=v
|
||||
|
||||
then let u' = u-v
|
||||
x1' = if x1 >= x2 then modp (x1-x2) else modp (x1+p-x2)
|
||||
in go x1' x2 u' v
|
||||
|
||||
else let v' = v-u
|
||||
x2' = if x2 >= x1 then modp (x2-x1) else modp (x2+p-x1)
|
||||
in go x1 x2' u v'
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
104
src/Graph.hs
Normal file
104
src/Graph.hs
Normal file
@ -0,0 +1,104 @@
|
||||
|
||||
{-# LANGUAGE StrictData #-}
|
||||
module Graph where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Text.Printf
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Data.Word
|
||||
|
||||
data Graph = Graph
|
||||
{ graphNodes :: [Node]
|
||||
, graphMeta :: GraphMetaData
|
||||
}
|
||||
deriving Show
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
-- | Unary operations
|
||||
data UnoOp
|
||||
= Neg -- ^ @= 0@
|
||||
| Id -- ^ @= 1@
|
||||
deriving (Eq,Enum,Bounded,Show)
|
||||
|
||||
data DuoOp
|
||||
= Mul -- ^ @= 0@
|
||||
| Div -- ^ @= 1@
|
||||
| Add -- ^ @= 2@
|
||||
| Sub -- ^ @= 3@
|
||||
| Pow -- ^ @= 4@
|
||||
| Idiv -- ^ @= 5@
|
||||
| Mod -- ^ @= 6@
|
||||
| Eq_ -- ^ @= 7@
|
||||
| Neq -- ^ @= 8@
|
||||
| Lt -- ^ @= 9@
|
||||
| Gt -- ^ @= 10@
|
||||
| Leq -- ^ @= 11@
|
||||
| Geq -- ^ @= 12@
|
||||
| Land -- ^ @= 13@
|
||||
| Lor -- ^ @= 14@
|
||||
| Shl -- ^ @= 15@
|
||||
| Shr -- ^ @= 16@
|
||||
| Bor -- ^ @= 17@
|
||||
| Band -- ^ @= 18@
|
||||
| Bxor -- ^ @= 19@
|
||||
deriving (Eq,Enum,Bounded,Show)
|
||||
|
||||
data TresOp
|
||||
= TernCond -- ^ @= 0@
|
||||
deriving (Eq,Enum,Bounded,Show)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
newtype BigUInt
|
||||
= BigUInt [Word8] -- ^ little endian
|
||||
|
||||
showBigUInt :: BigUInt -> String
|
||||
showBigUInt (BigUInt bytes) = "0x" ++ concatMap f (reverse bytes) where
|
||||
f :: Word8 -> String
|
||||
f = printf "%02x"
|
||||
|
||||
instance Show BigUInt where show = showBigUInt
|
||||
|
||||
newtype InputNode
|
||||
= InputNode Word32
|
||||
deriving (Show)
|
||||
|
||||
newtype ConstantNode
|
||||
= ConstantNode BigUInt
|
||||
deriving (Show)
|
||||
|
||||
data UnoOpNode = UnoOpNode !UnoOp !Word32 deriving (Show)
|
||||
data DuoOpNode = DuoOpNode !DuoOp !Word32 !Word32 deriving (Show)
|
||||
data TresOpNode = TresOpNode !TresOp !Word32 !Word32 !Word32 deriving (Show)
|
||||
|
||||
data Node
|
||||
= AnInputNode InputNode -- @= 1@
|
||||
| AConstantNode ConstantNode -- @= 2@
|
||||
| AnUnoOpNode UnoOpNode -- @= 3@
|
||||
| ADuoOpNode DuoOpNode -- @= 4@
|
||||
| ATresOpNode TresOpNode -- @= 5@
|
||||
deriving (Show)
|
||||
|
||||
data SignalDescription = SignalDescription
|
||||
{ signalOffset :: !Word32
|
||||
, signalLength :: !Word32
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
newtype WitnessMapping
|
||||
= WitnessMapping { fromWitnessMapping :: [Word32] }
|
||||
deriving (Show)
|
||||
|
||||
type CircuitInputs = [(String, SignalDescription)]
|
||||
|
||||
data GraphMetaData = GraphMetaData
|
||||
{ witnessMapping :: WitnessMapping
|
||||
, inputSignals :: CircuitInputs
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
34
src/Main.hs
Normal file
34
src/Main.hs
Normal file
@ -0,0 +1,34 @@
|
||||
|
||||
module Main where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as Map
|
||||
|
||||
import Witness
|
||||
import Parser
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
(~>) :: String -> a -> (String, a)
|
||||
(~>) = (,)
|
||||
|
||||
infix 2 ~>
|
||||
|
||||
testInputs :: Map String [Integer]
|
||||
testInputs = Map.fromList
|
||||
[ "a" ~> [0xff01]
|
||||
, "b" ~> [0xff02]
|
||||
]
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
Right graph <- parseGraphFile "../tmp/graph2.bin"
|
||||
putStrLn ""
|
||||
print graph
|
||||
let wtns = witnessCalc testInputs graph
|
||||
putStrLn ""
|
||||
print wtns
|
||||
33
src/Misc.hs
Normal file
33
src/Misc.hs
Normal file
@ -0,0 +1,33 @@
|
||||
|
||||
{-# LANGUAGE Strict #-}
|
||||
module Misc where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Data.Bits
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
class ShowHex a where
|
||||
showHex :: a -> String
|
||||
|
||||
printHex :: ShowHex a => a -> IO ()
|
||||
printHex x = putStrLn (showHex x)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Integer logarithm
|
||||
|
||||
-- | Largest integer @k@ such that @2^k@ is smaller or equal to @n@
|
||||
integerLog2 :: Integer -> Int
|
||||
integerLog2 n = go n where
|
||||
go 0 = -1
|
||||
go k = 1 + go (shiftR k 1)
|
||||
|
||||
-- | Smallest integer @k@ such that @2^k@ is larger or equal to @n@
|
||||
ceilingLog2 :: Integer -> Int
|
||||
ceilingLog2 0 = 0
|
||||
ceilingLog2 n = 1 + go (n-1) where
|
||||
go 0 = -1
|
||||
go k = 1 + go (shiftR k 1)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
345
src/Parser.hs
Normal file
345
src/Parser.hs
Normal file
@ -0,0 +1,345 @@
|
||||
|
||||
-- | Parsing the graph binary format
|
||||
|
||||
{-# LANGUAGE Strict, PackageImports, BangPatterns #-}
|
||||
module Parser where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Data.Bits
|
||||
import Data.Word
|
||||
import Data.List
|
||||
import Data.Ord
|
||||
|
||||
import Control.Monad
|
||||
import Control.Applicative
|
||||
import System.IO
|
||||
|
||||
import Data.ByteString.Lazy (ByteString)
|
||||
import qualified Data.ByteString.Lazy as L
|
||||
import qualified Data.ByteString.Lazy.Char8 as LC
|
||||
|
||||
import "binary" Data.Binary.Get
|
||||
import "binary" Data.Binary.Builder as Builder
|
||||
|
||||
import Graph
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
{-
|
||||
test :: IO ()
|
||||
test = do
|
||||
Right graph <- parseGraphFile "../tmp/graph.bin"
|
||||
print graph
|
||||
-}
|
||||
|
||||
{-
|
||||
nodeEx1 = 0x06 : nodeEx1' :: [Word8]
|
||||
nodeEx2 = 0x08 : nodeEx2' :: [Word8]
|
||||
|
||||
nodeEx1' = 0x22 : 0x04 : duoNodeEx1 :: [Word8]
|
||||
nodeEx2' = 0x22 : 0x06 : duoNodeEx2 :: [Word8]
|
||||
|
||||
duoNodeEx1 = [ 0x10 , 0x05 , 0x18 , 0x05 ] :: [Word8]
|
||||
duoNodeEx2 = [ 0x08 , 0x02 , 0x10 , 0x03 , 0x18 , 0x04 ] :: [Word8]
|
||||
-}
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
type Msg = String
|
||||
|
||||
parseGraphFile :: FilePath -> IO (Either Msg Graph)
|
||||
parseGraphFile fname = do
|
||||
h <- openBinaryFile fname ReadMode
|
||||
ei <- readGraphFile h
|
||||
hClose h
|
||||
return ei
|
||||
|
||||
hGetBytes :: Handle -> Int -> IO ByteString
|
||||
hGetBytes h n = L.hGet h (fromIntegral n)
|
||||
|
||||
hSeekInt :: Handle -> Int -> IO ()
|
||||
hSeekInt h ofs = hSeek h AbsoluteSeek (fromIntegral ofs)
|
||||
|
||||
readGraphFile :: Handle -> IO (Either Msg Graph)
|
||||
readGraphFile h = do
|
||||
flen <- (fromIntegral :: Integer -> Int) <$> hFileSize h
|
||||
magic <- hGetBytes h (length magicHeader)
|
||||
if magic /= LC.pack magicHeader
|
||||
then return $ Left "magic header not found or invalid"
|
||||
else do
|
||||
hSeekInt h (flen - 8)
|
||||
offset <- (fromIntegral . runGet getWord64le) <$> hGetBytes h 8
|
||||
putStrLn $ "metadata offset = " ++ show offset
|
||||
if (offset >= flen) || (offset <= 18)
|
||||
then return $ Left "invalid final `graphMetaData` offset bytes"
|
||||
else do
|
||||
hSeekInt h (length magicHeader)
|
||||
part1 <- hGetBytes h (offset - length magicHeader)
|
||||
part2 <- hGetBytes h (flen - offset - 8)
|
||||
return (Right $ Graph (parseNodes part1) (parseMeta part2))
|
||||
|
||||
magicHeader :: String
|
||||
magicHeader = "wtns.graph.001"
|
||||
|
||||
parseNodes :: ByteString -> [Node]
|
||||
parseNodes = runGet getNodes
|
||||
|
||||
parseMeta :: ByteString -> GraphMetaData
|
||||
parseMeta = runGet getMetaData
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
varInt' :: Get Word64
|
||||
varInt' = go 0 where
|
||||
go !cnt = if cnt >= 8 then return 0 else do
|
||||
w <- getWord8
|
||||
if (w < 128)
|
||||
then return (fromIntegral w)
|
||||
else do
|
||||
let x = fromIntegral (w .&. 127)
|
||||
y <- go (cnt+1)
|
||||
return (x + 128*y)
|
||||
|
||||
varInt :: Get Int
|
||||
varInt = fromIntegral <$> varInt'
|
||||
|
||||
varUInt :: Get Word32
|
||||
varUInt = fromIntegral <$> varInt'
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
getNodes :: Get [Node]
|
||||
getNodes = do
|
||||
n <- getWord64le
|
||||
replicateM (fromIntegral n) getNode
|
||||
|
||||
-- | with varint length prefix
|
||||
getNode :: Get Node
|
||||
getNode = do
|
||||
len <- varInt
|
||||
bs <- getLazyByteString (fromIntegral len)
|
||||
return (runGet getNode' bs)
|
||||
|
||||
-- | without varint length prefix
|
||||
getNode' :: Get Node
|
||||
getNode' = do
|
||||
nodetype <- getFieldId LEN
|
||||
case nodetype of
|
||||
1 -> AnInputNode <$> getInputNode
|
||||
2 -> AConstantNode <$> getConstantNode
|
||||
3 -> AnUnoOpNode <$> getUnoOpNode
|
||||
4 -> ADuoOpNode <$> getDuoOpNode
|
||||
5 -> ATresOpNode <$> getTresOpNode
|
||||
_ -> error "unexpected node type"
|
||||
|
||||
getInputNode :: Get InputNode
|
||||
getInputNode = do
|
||||
SomeNode idx _ _ _ <- getSomeNode
|
||||
return (InputNode idx)
|
||||
|
||||
getConstantNode :: Get ConstantNode
|
||||
getConstantNode = do
|
||||
len <- varInt
|
||||
bs <- getLazyByteString (fromIntegral len)
|
||||
return $ ConstantNode (runGet getBigUInt bs)
|
||||
|
||||
getBigUInt :: Get BigUInt
|
||||
getBigUInt = do
|
||||
fld <- getFieldId LEN
|
||||
if fld /= 1
|
||||
then error "getBigUInt"
|
||||
else do
|
||||
len <- varInt
|
||||
bs <- getLazyByteString (fromIntegral len)
|
||||
return $ BigUInt (runGet getByteList bs)
|
||||
|
||||
getByteList :: Get [Word8]
|
||||
getByteList = do
|
||||
fld <- getFieldId LEN
|
||||
if fld /= 1
|
||||
then error "getByteList"
|
||||
else do
|
||||
len <- varInt
|
||||
bs <- getLazyByteString (fromIntegral len)
|
||||
return (L.unpack bs)
|
||||
|
||||
getUnoOpNode :: Get UnoOpNode
|
||||
getUnoOpNode = do
|
||||
SomeNode op arg1 _ _ <- getSomeNode
|
||||
return (UnoOpNode (wordToEnum op) arg1)
|
||||
|
||||
getDuoOpNode :: Get DuoOpNode
|
||||
getDuoOpNode = do
|
||||
SomeNode op arg1 arg2 _ <- getSomeNode
|
||||
return (DuoOpNode (wordToEnum op) arg1 arg2)
|
||||
|
||||
getTresOpNode :: Get TresOpNode
|
||||
getTresOpNode = do
|
||||
SomeNode op arg1 arg2 arg3 <- getSomeNode
|
||||
return (TresOpNode (wordToEnum op) arg1 arg2 arg3)
|
||||
|
||||
wordToEnum :: Enum a => Word32 -> a
|
||||
wordToEnum w = toEnum (fromIntegral w)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
data SomeNode = SomeNode
|
||||
{ field1 :: Word32
|
||||
, field2 :: Word32
|
||||
, field3 :: Word32
|
||||
, field4 :: Word32
|
||||
}
|
||||
deriving Show
|
||||
|
||||
defaultSomeNode :: SomeNode
|
||||
defaultSomeNode = SomeNode 0 0 0 0
|
||||
|
||||
insert1 :: (Int,Word32) -> SomeNode -> SomeNode
|
||||
insert1 (idx,val) old = case idx of
|
||||
1 -> old { field1 = val }
|
||||
2 -> old { field2 = val }
|
||||
3 -> old { field3 = val }
|
||||
4 -> old { field4 = val }
|
||||
|
||||
insertMany :: [(Int,Word32)] -> SomeNode -> SomeNode
|
||||
insertMany list old = foldl' (flip insert1) old list
|
||||
|
||||
getSomeNode :: Get SomeNode
|
||||
getSomeNode = do
|
||||
len <- varInt
|
||||
bs <- getLazyByteString (fromIntegral len)
|
||||
let list = runGet getRecord bs
|
||||
return $ insertMany list defaultSomeNode
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- TODO: refactor this mess
|
||||
|
||||
getMetaData :: Get GraphMetaData
|
||||
getMetaData = do
|
||||
len <- varInt
|
||||
mapping <- getWitnessMapping
|
||||
inputs <- getCircuitInputs
|
||||
return $ GraphMetaData mapping inputs
|
||||
|
||||
getWitnessMapping :: Get WitnessMapping
|
||||
getWitnessMapping = do
|
||||
fld <- getFieldId LEN
|
||||
if fld /= 1
|
||||
then error "getWitnessMapping: expecting field 1"
|
||||
else do
|
||||
len <- varInt
|
||||
bs <- getLazyByteString (fromIntegral len)
|
||||
return $ WitnessMapping (runGet worker bs)
|
||||
where
|
||||
worker :: Get [Word32]
|
||||
worker = isEmpty >>= \b -> if b
|
||||
then return []
|
||||
else (:) <$> varUInt <*> worker
|
||||
|
||||
getCircuitInputs :: Get CircuitInputs
|
||||
getCircuitInputs = worker where
|
||||
|
||||
worker :: Get [(String, SignalDescription)]
|
||||
worker = isEmpty >>= \b -> if b
|
||||
then return []
|
||||
else (:) <$> getSingleInput <*> worker
|
||||
|
||||
getSingleInput :: Get (String, SignalDescription)
|
||||
getSingleInput = do
|
||||
fld <- getFieldId LEN
|
||||
if fld /= 2
|
||||
then error "getCircuitInputs: expecting field 2"
|
||||
else do
|
||||
len <- varInt
|
||||
bs <- getLazyByteString (fromIntegral len)
|
||||
return $ runGet inputHelper bs
|
||||
|
||||
inputHelper = do
|
||||
name <- getName
|
||||
signal <- getSignal
|
||||
return (name,signal)
|
||||
|
||||
getName :: Get String
|
||||
getName = do
|
||||
fld <- getFieldId LEN
|
||||
if fld /= 1
|
||||
then error "getCircuitInputs/getName: expecting field 1"
|
||||
else do
|
||||
len <- varInt
|
||||
bs <- getLazyByteString (fromIntegral len)
|
||||
return (LC.unpack bs)
|
||||
|
||||
getSignal :: Get SignalDescription
|
||||
getSignal = do
|
||||
fld <- getFieldId LEN
|
||||
if fld /= 2
|
||||
then error "getCircuitInputs/getSignal: expecting field 2"
|
||||
else do
|
||||
len <- varInt
|
||||
bs <- getLazyByteString (fromIntegral len)
|
||||
return $ runGet signalHelper bs
|
||||
|
||||
signalHelper = do
|
||||
ofs <- getSignalOffset
|
||||
len <- getSignalLength
|
||||
return $ SignalDescription { signalOffset = ofs , signalLength = len }
|
||||
|
||||
getSignalOffset = do
|
||||
fld <- getFieldId VARINT
|
||||
if fld /= 1
|
||||
then error "getCircuitInputs/getSignalOffset: expecting field 1"
|
||||
else varUInt
|
||||
|
||||
getSignalLength = do
|
||||
fld <- getFieldId VARINT
|
||||
if fld /= 2
|
||||
then error "getCircuitInputs/getSignalLength: expecting field 2"
|
||||
else varUInt
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * protobuf stuff
|
||||
|
||||
-- | There are six wire types: VARINT, I64, LEN, SGROUP, EGROUP, and I32
|
||||
data WireType
|
||||
= VARINT -- ^ used for: int32, int64, uint32, uint64, sint32, sint64, bool, enum
|
||||
| I64 -- ^ used for: fixed64, sfixed64, double
|
||||
| LEN -- ^ used for: string, bytes, embedded messages, packed repeated fields
|
||||
| SGROUP -- ^ used for: group start (deprecated)
|
||||
| EGROUP -- ^ used for: group end (deprecated)
|
||||
| I32 -- ^ fixed32, sfixed32, float
|
||||
deriving (Eq,Show,Enum,Bounded)
|
||||
|
||||
type FieldId = Int
|
||||
|
||||
getFieldId :: WireType -> Get FieldId
|
||||
getFieldId wty = do
|
||||
tag <- getWord8
|
||||
let (fld,wty') = decodeTag tag
|
||||
if wty == wty'
|
||||
then return fld
|
||||
else error "getFieldId: unexpected protobuf wire type"
|
||||
|
||||
decodeTag_ :: Word8 -> FieldId
|
||||
decodeTag_ = fst . decodeTag
|
||||
|
||||
decodeTag :: Word8 -> (FieldId, WireType)
|
||||
decodeTag w = (fld , wty) where
|
||||
fld = fromIntegral (shiftR w 3)
|
||||
wty = toEnum (fromIntegral (w .&. 7))
|
||||
|
||||
-- (index, value) pair
|
||||
getEntry :: Get (Int,Word32)
|
||||
getEntry = do
|
||||
idx <- getFieldId VARINT
|
||||
val <- varUInt
|
||||
return (idx,val)
|
||||
|
||||
-- list of (index, value) pairs
|
||||
getRecord :: Get [(Int,Word32)]
|
||||
getRecord = sort <$> go where
|
||||
go = isEmpty >>= \b -> if b
|
||||
then return []
|
||||
else (:) <$> getEntry <*> go
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
101
src/Semantics.hs
Normal file
101
src/Semantics.hs
Normal file
@ -0,0 +1,101 @@
|
||||
|
||||
-- | See <https://docs.circom.io/circom-language/basic-operators> for the official
|
||||
-- semantics of the operations
|
||||
|
||||
{-# LANGUAGE StrictData, DeriveFunctor #-}
|
||||
module Semantics where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Data.Bits
|
||||
|
||||
import BN254
|
||||
import Misc
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
data PrimOp a
|
||||
-- unary
|
||||
= Neg a
|
||||
| Id a
|
||||
-- binary
|
||||
| Mul a a
|
||||
| Div a a
|
||||
| Add a a
|
||||
| Sub a a
|
||||
| Pow a a
|
||||
| Idiv a a
|
||||
| Mod a a
|
||||
| Eq_ a a
|
||||
| Neq a a
|
||||
| Lt a a
|
||||
| Gt a a
|
||||
| Leq a a
|
||||
| Geq a a
|
||||
| Land a a
|
||||
| Lor a a
|
||||
| Shl a a
|
||||
| Shr a a
|
||||
| Bor a a
|
||||
| Band a a
|
||||
| Bxor a a
|
||||
-- ternary
|
||||
| Cond a a a
|
||||
deriving (Show,Functor)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
evalPrimOp :: PrimOp F -> F
|
||||
evalPrimOp prim = case prim of
|
||||
Neg x -> neg x
|
||||
Id x -> x
|
||||
Mul x y -> mul x y
|
||||
Div x y -> BN254.div x y
|
||||
Add x y -> add x y
|
||||
Sub x y -> sub x y
|
||||
Pow x y -> pow x (fromF y)
|
||||
Idiv x y -> if isZero y then 0 else toF $ Prelude.div (fromF x) (fromF y)
|
||||
Mod x y -> if isZero y then 0 else toF $ Prelude.mod (fromF x) (fromF y)
|
||||
Eq_ x y -> fromBool (x == y)
|
||||
Neq x y -> fromBool (x /= y)
|
||||
Lt x y -> fromBool (signedFromF x < signedFromF y)
|
||||
Gt x y -> fromBool (signedFromF x > signedFromF y)
|
||||
Leq x y -> fromBool (signedFromF x <= signedFromF y)
|
||||
Geq x y -> fromBool (signedFromF x >= signedFromF y)
|
||||
Land x y -> fromBool (toBool x && toBool y)
|
||||
Lor x y -> fromBool (toBool x || toBool y)
|
||||
Shl x y -> shiftLeft x (fromF y)
|
||||
Shr x y -> shiftRight x (fromF y)
|
||||
Bor x y -> toF (fromF x .|. fromF y)
|
||||
Band x y -> toF (fromF x .&. fromF y)
|
||||
Bxor x y -> toF (fromF x `xor` fromF y)
|
||||
Cond b x y -> if toBool b then x else y
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
fieldMask :: Integer
|
||||
fieldMask = shiftL 1 fieldBits - 1
|
||||
|
||||
fieldBits :: Int
|
||||
fieldBits = ceilingLog2 fieldPrime
|
||||
|
||||
shiftRight :: F -> Integer -> F
|
||||
shiftRight (MkF x) k = if k < halfPrimePlus1
|
||||
then if k > fromIntegral fieldBits
|
||||
then 0
|
||||
else MkF (shiftR x (fromInteger k))
|
||||
else shiftLeft (MkF x) (fieldPrime - k)
|
||||
|
||||
shiftLeft :: F -> Integer -> F
|
||||
shiftLeft (MkF x) k = if k < halfPrimePlus1
|
||||
then if k > fromIntegral fieldBits
|
||||
then 0
|
||||
else toF (shiftL x (fromIntegral k) .&. fieldMask)
|
||||
else shiftRight (MkF x) (fieldPrime - k)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
{-
|
||||
notYet :: a
|
||||
notYet = error "not yet implemented"
|
||||
-}
|
||||
104
src/Witness.hs
Normal file
104
src/Witness.hs
Normal file
@ -0,0 +1,104 @@
|
||||
|
||||
{-# LANGUAGE StrictData #-}
|
||||
module Witness where
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
import Data.Array
|
||||
import Data.Word
|
||||
|
||||
import qualified Data.Map as Map ; import Data.Map (Map )
|
||||
import qualified Data.IntMap as IntMap ; import Data.IntMap (IntMap)
|
||||
|
||||
import BN254
|
||||
|
||||
import qualified Semantics as S ; import Semantics ( PrimOp , evalPrimOp )
|
||||
import qualified Graph as G ; import Graph ( Graph(..) , Node(..) , UnoOpNode(..) , DuoOpNode(..) , TresOpNode(..) , SignalDescription(..) )
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
type Witness = Array Int F
|
||||
|
||||
type Inputs = Map String [Integer]
|
||||
|
||||
witnessCalc :: Inputs -> Graph -> Witness
|
||||
witnessCalc inputs (Graph nodes meta) = witness where
|
||||
nodesArr = listArray (0,length nodes-1) nodes
|
||||
rawWitness = evaluateNodes rawInputs nodesArr
|
||||
rawInputs = convertInputs (G.inputSignals meta) inputs
|
||||
mapping_ = G.fromWitnessMapping (G.witnessMapping meta)
|
||||
wtnslen = length mapping_
|
||||
mapping = listArray (0,wtnslen-1) mapping_
|
||||
witness = listArray (0,wtnslen-1) $ [ rawWitness!(fromIntegral (mapping!i)) | i<-[0..wtnslen-1] ]
|
||||
|
||||
convertInputs :: [(String,SignalDescription)] -> Map String [Integer] -> IntMap F
|
||||
convertInputs descTable inputTable = IntMap.fromList $ concatMap f descTable where
|
||||
f :: (String,SignalDescription) -> [(Int,F)]
|
||||
f (name,desc) = case Map.lookup name inputTable of
|
||||
Nothing -> error $ "input signal `" ++ name ++ "` not found in the given inputs!"
|
||||
Just values -> if length values /= fromIntegral (signalLength desc)
|
||||
then error $ "input signal `" ++ name ++ "` has incorrect size"
|
||||
else let ofs = fromIntegral (signalOffset desc) :: Int
|
||||
in (0,1) : zip [ofs..] (map toF values)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
type RawInputs = IntMap F
|
||||
|
||||
evaluateNodes :: RawInputs -> Array Int Node -> Array Int F
|
||||
evaluateNodes inputs nodes = witness where
|
||||
|
||||
(a,b) = bounds nodes
|
||||
witness = array (0,b-a) $ [ (i-a, worker i) | i<-[a..b] ]
|
||||
|
||||
lkp :: Word32 -> F
|
||||
lkp i = witness!(fromIntegral i)
|
||||
|
||||
worker i = case nodes!i of
|
||||
AnInputNode (G.InputNode idx) -> getInputValue idx
|
||||
AConstantNode (G.ConstantNode big) -> fromBigUInt big
|
||||
AnUnoOpNode unoOpNode -> evalPrimOp (fmap lkp $ unoToPrimOp unoOpNode )
|
||||
ADuoOpNode duoOpNode -> evalPrimOp (fmap lkp $ duoToPrimOp duoOpNode )
|
||||
ATresOpNode tresOpNode -> evalPrimOp (fmap lkp $ tresToPrimOp tresOpNode)
|
||||
|
||||
getInputValue j = case IntMap.lookup (fromIntegral j) inputs of
|
||||
Just y -> y
|
||||
Nothing -> error ("input value not found at index " ++ show j)
|
||||
|
||||
fromBigUInt (G.BigUInt bs) = fromBytesLE bs
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
unoToPrimOp :: UnoOpNode -> PrimOp Word32
|
||||
unoToPrimOp (UnoOpNode op arg1) = case op of
|
||||
G.Neg -> S.Neg arg1
|
||||
G.Id -> S.Id arg1
|
||||
|
||||
duoToPrimOp :: DuoOpNode -> PrimOp Word32
|
||||
duoToPrimOp (DuoOpNode op arg1 arg2) = case op of
|
||||
G.Mul -> S.Mul arg1 arg2
|
||||
G.Div -> S.Div arg1 arg2
|
||||
G.Add -> S.Add arg1 arg2
|
||||
G.Sub -> S.Sub arg1 arg2
|
||||
G.Pow -> S.Pow arg1 arg2
|
||||
G.Idiv -> S.Idiv arg1 arg2
|
||||
G.Mod -> S.Mod arg1 arg2
|
||||
G.Eq_ -> S.Eq_ arg1 arg2
|
||||
G.Neq -> S.Neq arg1 arg2
|
||||
G.Lt -> S.Lt arg1 arg2
|
||||
G.Gt -> S.Gt arg1 arg2
|
||||
G.Leq -> S.Leq arg1 arg2
|
||||
G.Geq -> S.Geq arg1 arg2
|
||||
G.Land -> S.Land arg1 arg2
|
||||
G.Lor -> S.Lor arg1 arg2
|
||||
G.Shl -> S.Shl arg1 arg2
|
||||
G.Shr -> S.Shr arg1 arg2
|
||||
G.Bor -> S.Bor arg1 arg2
|
||||
G.Band -> S.Band arg1 arg2
|
||||
G.Bxor -> S.Bxor arg1 arg2
|
||||
|
||||
tresToPrimOp :: TresOpNode -> PrimOp Word32
|
||||
tresToPrimOp (TresOpNode op arg1 arg2 arg3) = case op of
|
||||
G.TernCond -> S.Cond arg1 arg2 arg3
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
Loading…
x
Reference in New Issue
Block a user