diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..68df01b --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +.DS_Store +*.hi +*.o +tmp/ +build/ +.ghc.environment.* \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..efd1bf7 --- /dev/null +++ b/README.md @@ -0,0 +1,55 @@ + +Circom witness generators +------------------------- + +This piece of software takes the computation graph files generated by +[`circom-witnesscalc`](https://github.com/iden3/circom-witnesscalc), +and either interprets or compiles them to various algebra backends. + +### Implementation status + +Compiler (in Haskell): + +- [x] parsing the graph files +- [x] naive interpreter +- [ ] constantine backend +- [ ] zikkurat backend +- [ ] arkworks backend + +Nim witness generator: + +- [x] parsing the graph files +- [ ] generating the witness +- [ ] proper error handling + +### Testing & correctness + +I haven't yet done any proper testing, apart from "works for our purposes". + +### Circuit optimizations + +NOTE: you _have to_ run `circom` with the `--O2` options, otherwise the +witness format will be most probably incompatible with the the one generated +by `circom-witnesscalc`. + +### Graph file format + +`circom-witnesscalc` produces binary files encoding a computation graph. + +This has the following format: + +- magic header: "wtns.graph.001" (14 bytes) +- number of nodes (8 bytes little endian) +- list of nodes, in protobuf serialized format. Each prefixed by a `varint` length +- `GraphMetaData`, in protobuf +- 8 bytes offset (yes, _at the very end_...), pointing to the start of `GraphMetaData` + +Node format: + +- varint length prefix +- protobuf tag-byte (lower 3 bits are `0x02`, upper bits are node type 1..5) +- varint length +- several record fields + - protobuf tag-byte (lower 3 bits are 0x00, upper bits are field index 1,2,3,4) + - value is varint word32, except for ConstantNode when it's a little-endian bytes + (wrapped a few times, because protobuf is being protobuf...) diff --git a/src/BN254.hs b/src/BN254.hs new file mode 100644 index 0000000..e47ebfd --- /dev/null +++ b/src/BN254.hs @@ -0,0 +1,191 @@ + +-- | The BN254 scalar field + +{-# LANGUAGE Strict, BangPatterns #-} +module BN254 where + +-------------------------------------------------------------------------------- + +import Prelude hiding (div) +import qualified Prelude + +import Data.Bits +import Data.Word +import Data.Ratio + +import System.Random +import Text.Printf + +import Misc + +-------------------------------------------------------------------------------- + +fieldPrime :: Integer +fieldPrime = 21888242871839275222246405745257275088548364400416034343698204186575808495617 + +modP :: Integer -> Integer +modP x = mod x fieldPrime + +halfPrimePlus1 :: Integer +halfPrimePlus1 = 1 + Prelude.div fieldPrime 2 + +-------------------------------------------------------------------------------- + +newtype F + = MkF Integer + deriving (Eq,Show) + +fromF :: F -> Integer +fromF (MkF x) = x + +-- from the circom docs: @val(z) = z-p if p/2 +1 <= z < p@ +signedFromF :: F -> Integer +signedFromF (MkF x) = if x >= halfPrimePlus1 then x - fieldPrime else x + +toF :: Integer -> F +toF = MkF . modP + +isZero :: F -> Bool +isZero (MkF x) = (x == 0) + +fromBool :: Bool -> F +fromBool b = if b then 1 else 0 + +toBool :: F -> Bool +toBool = not . isZero + +-------------------------------------------------------------------------------- + +neg :: F -> F +neg (MkF x) = toF (negate x) + +add :: F -> F -> F +add (MkF x) (MkF y) = toF (x+y) + +sub :: F -> F -> F +sub (MkF x) (MkF y) = toF (x-y) + +mul :: F -> F -> F +mul (MkF x) (MkF y) = toF (x*y) + +instance Num F where + fromInteger = toF + negate = neg + (+) = add + (-) = sub + (*) = mul + abs = id + signum _ = toF 1 + +square :: F -> F +square x = x*x + +rndF :: IO F +rndF = MkF <$> randomRIO (0,fieldPrime-1) + +-------------------------------------------------------------------------------- + +pow :: F -> Integer -> F +pow x0 exponent + | exponent < 0 = error "power: expecting positive exponent" + | otherwise = go 1 x0 exponent + where + go !acc _ 0 = acc + go !acc s e = go acc' s' (shiftR e 1) where + s' = s*s + acc' = if e .&. 1 == 0 then acc else acc*s + +invNaive :: F -> F +invNaive x = pow x (fieldPrime - 2) + +divNaive :: F -> F -> F +divNaive x y = x * invNaive y + +-------------------------------------------------------------------------------- + +instance Fractional F where + fromRational q = fromInteger (numerator q) / fromInteger (denominator q) + recip = inv + (/) = div + +-------------------------------------------------------------------------------- + +fromBytesLE :: [Word8] -> F +fromBytesLE = toF . integerFromBytesLE + +integerFromBytesLE :: [Word8] -> Integer +integerFromBytesLE = go where + go [] = 0 + go (b:bs) = fromIntegral b + (shiftL (go bs) 8) + +-------------------------------------------------------------------------------- + +instance ShowHex Integer where showHex = printf "0x%x" +instance ShowHex F where showHex (MkF x) = showHex x + +-------------------------------------------------------------------------------- + +-- | Inversion (using Euclid's algorithm) +inv :: F -> F +inv (MkF a) + | a == 0 = 0 -- error "field inverse of zero (generic prime)" + | otherwise = MkF (euclid 1 0 a fieldPrime) + +-- | Division via Euclid's algorithm +div :: F -> F -> F +div (MkF a) (MkF b) + | b == 0 = 0 -- error "field division by zero (generic prime)" + | otherwise = MkF (euclid a 0 b fieldPrime) + +-------------------------------------------------------------------------------- +-- * Euclidean algorithm + +-- | Extended binary Euclidean algorithm +euclid :: Integer -> Integer -> Integer -> Integer -> Integer +euclid !x1 !x2 !u !v = go x1 x2 u v where + + p = fieldPrime + + halfp1 = shiftR (p+1) 1 + + modp :: Integer -> Integer + modp n = mod n p + + -- Inverse using the binary Euclidean algorithm + euclid :: Integer -> Integer + euclid a + | a == 0 = 0 + | otherwise = go 1 0 a p + + go :: Integer -> Integer -> Integer -> Integer -> Integer + go !x1 !x2 !u !v + | u==1 = x1 + | v==1 = x2 + | otherwise = stepU x1 x2 u v + + stepU :: Integer -> Integer -> Integer -> Integer -> Integer + stepU !x1 !x2 !u !v = if even u + then let u' = shiftR u 1 + x1' = if even x1 then shiftR x1 1 else shiftR x1 1 + halfp1 + in stepU x1' x2 u' v + else stepV x1 x2 u v + + stepV :: Integer -> Integer -> Integer -> Integer -> Integer + stepV !x1 !x2 !u !v = if even v + then let v' = shiftR v 1 + x2' = if even x2 then shiftR x2 1 else shiftR x2 1 + halfp1 + in stepV x1 x2' u v' + else final x1 x2 u v + + final :: Integer -> Integer -> Integer -> Integer -> Integer + final !x1 !x2 !u !v = if u>=v + + then let u' = u-v + x1' = if x1 >= x2 then modp (x1-x2) else modp (x1+p-x2) + in go x1' x2 u' v + + else let v' = v-u + x2' = if x2 >= x1 then modp (x2-x1) else modp (x2+p-x1) + in go x1 x2' u v' + +-------------------------------------------------------------------------------- diff --git a/src/Graph.hs b/src/Graph.hs new file mode 100644 index 0000000..f71dac0 --- /dev/null +++ b/src/Graph.hs @@ -0,0 +1,104 @@ + +{-# LANGUAGE StrictData #-} +module Graph where + +-------------------------------------------------------------------------------- + +import Text.Printf + +-------------------------------------------------------------------------------- + +import Data.Word + +data Graph = Graph + { graphNodes :: [Node] + , graphMeta :: GraphMetaData + } + deriving Show + +-------------------------------------------------------------------------------- + +-- | Unary operations +data UnoOp + = Neg -- ^ @= 0@ + | Id -- ^ @= 1@ + deriving (Eq,Enum,Bounded,Show) + +data DuoOp + = Mul -- ^ @= 0@ + | Div -- ^ @= 1@ + | Add -- ^ @= 2@ + | Sub -- ^ @= 3@ + | Pow -- ^ @= 4@ + | Idiv -- ^ @= 5@ + | Mod -- ^ @= 6@ + | Eq_ -- ^ @= 7@ + | Neq -- ^ @= 8@ + | Lt -- ^ @= 9@ + | Gt -- ^ @= 10@ + | Leq -- ^ @= 11@ + | Geq -- ^ @= 12@ + | Land -- ^ @= 13@ + | Lor -- ^ @= 14@ + | Shl -- ^ @= 15@ + | Shr -- ^ @= 16@ + | Bor -- ^ @= 17@ + | Band -- ^ @= 18@ + | Bxor -- ^ @= 19@ + deriving (Eq,Enum,Bounded,Show) + +data TresOp + = TernCond -- ^ @= 0@ + deriving (Eq,Enum,Bounded,Show) + +-------------------------------------------------------------------------------- + +newtype BigUInt + = BigUInt [Word8] -- ^ little endian + +showBigUInt :: BigUInt -> String +showBigUInt (BigUInt bytes) = "0x" ++ concatMap f (reverse bytes) where + f :: Word8 -> String + f = printf "%02x" + +instance Show BigUInt where show = showBigUInt + +newtype InputNode + = InputNode Word32 + deriving (Show) + +newtype ConstantNode + = ConstantNode BigUInt + deriving (Show) + +data UnoOpNode = UnoOpNode !UnoOp !Word32 deriving (Show) +data DuoOpNode = DuoOpNode !DuoOp !Word32 !Word32 deriving (Show) +data TresOpNode = TresOpNode !TresOp !Word32 !Word32 !Word32 deriving (Show) + +data Node + = AnInputNode InputNode -- @= 1@ + | AConstantNode ConstantNode -- @= 2@ + | AnUnoOpNode UnoOpNode -- @= 3@ + | ADuoOpNode DuoOpNode -- @= 4@ + | ATresOpNode TresOpNode -- @= 5@ + deriving (Show) + +data SignalDescription = SignalDescription + { signalOffset :: !Word32 + , signalLength :: !Word32 + } + deriving (Show) + +newtype WitnessMapping + = WitnessMapping { fromWitnessMapping :: [Word32] } + deriving (Show) + +type CircuitInputs = [(String, SignalDescription)] + +data GraphMetaData = GraphMetaData + { witnessMapping :: WitnessMapping + , inputSignals :: CircuitInputs + } + deriving (Show) + +-------------------------------------------------------------------------------- diff --git a/src/Main.hs b/src/Main.hs new file mode 100644 index 0000000..e4db71d --- /dev/null +++ b/src/Main.hs @@ -0,0 +1,34 @@ + +module Main where + +-------------------------------------------------------------------------------- + +import Data.Map (Map) +import qualified Data.Map as Map + +import Witness +import Parser + +-------------------------------------------------------------------------------- + +(~>) :: String -> a -> (String, a) +(~>) = (,) + +infix 2 ~> + +testInputs :: Map String [Integer] +testInputs = Map.fromList + [ "a" ~> [0xff01] + , "b" ~> [0xff02] + ] + +-------------------------------------------------------------------------------- + +main :: IO () +main = do + Right graph <- parseGraphFile "../tmp/graph2.bin" + putStrLn "" + print graph + let wtns = witnessCalc testInputs graph + putStrLn "" + print wtns \ No newline at end of file diff --git a/src/Misc.hs b/src/Misc.hs new file mode 100644 index 0000000..19bf8ae --- /dev/null +++ b/src/Misc.hs @@ -0,0 +1,33 @@ + +{-# LANGUAGE Strict #-} +module Misc where + +-------------------------------------------------------------------------------- + +import Data.Bits + +-------------------------------------------------------------------------------- + +class ShowHex a where + showHex :: a -> String + +printHex :: ShowHex a => a -> IO () +printHex x = putStrLn (showHex x) + +-------------------------------------------------------------------------------- +-- Integer logarithm + +-- | Largest integer @k@ such that @2^k@ is smaller or equal to @n@ +integerLog2 :: Integer -> Int +integerLog2 n = go n where + go 0 = -1 + go k = 1 + go (shiftR k 1) + +-- | Smallest integer @k@ such that @2^k@ is larger or equal to @n@ +ceilingLog2 :: Integer -> Int +ceilingLog2 0 = 0 +ceilingLog2 n = 1 + go (n-1) where + go 0 = -1 + go k = 1 + go (shiftR k 1) + +-------------------------------------------------------------------------------- diff --git a/src/Parser.hs b/src/Parser.hs new file mode 100644 index 0000000..140b83f --- /dev/null +++ b/src/Parser.hs @@ -0,0 +1,345 @@ + +-- | Parsing the graph binary format + +{-# LANGUAGE Strict, PackageImports, BangPatterns #-} +module Parser where + +-------------------------------------------------------------------------------- + +import Data.Bits +import Data.Word +import Data.List +import Data.Ord + +import Control.Monad +import Control.Applicative +import System.IO + +import Data.ByteString.Lazy (ByteString) +import qualified Data.ByteString.Lazy as L +import qualified Data.ByteString.Lazy.Char8 as LC + +import "binary" Data.Binary.Get +import "binary" Data.Binary.Builder as Builder + +import Graph + +-------------------------------------------------------------------------------- + +{- +test :: IO () +test = do + Right graph <- parseGraphFile "../tmp/graph.bin" + print graph +-} + +{- +nodeEx1 = 0x06 : nodeEx1' :: [Word8] +nodeEx2 = 0x08 : nodeEx2' :: [Word8] + +nodeEx1' = 0x22 : 0x04 : duoNodeEx1 :: [Word8] +nodeEx2' = 0x22 : 0x06 : duoNodeEx2 :: [Word8] + +duoNodeEx1 = [ 0x10 , 0x05 , 0x18 , 0x05 ] :: [Word8] +duoNodeEx2 = [ 0x08 , 0x02 , 0x10 , 0x03 , 0x18 , 0x04 ] :: [Word8] +-} + +-------------------------------------------------------------------------------- + +type Msg = String + +parseGraphFile :: FilePath -> IO (Either Msg Graph) +parseGraphFile fname = do + h <- openBinaryFile fname ReadMode + ei <- readGraphFile h + hClose h + return ei + +hGetBytes :: Handle -> Int -> IO ByteString +hGetBytes h n = L.hGet h (fromIntegral n) + +hSeekInt :: Handle -> Int -> IO () +hSeekInt h ofs = hSeek h AbsoluteSeek (fromIntegral ofs) + +readGraphFile :: Handle -> IO (Either Msg Graph) +readGraphFile h = do + flen <- (fromIntegral :: Integer -> Int) <$> hFileSize h + magic <- hGetBytes h (length magicHeader) + if magic /= LC.pack magicHeader + then return $ Left "magic header not found or invalid" + else do + hSeekInt h (flen - 8) + offset <- (fromIntegral . runGet getWord64le) <$> hGetBytes h 8 + putStrLn $ "metadata offset = " ++ show offset + if (offset >= flen) || (offset <= 18) + then return $ Left "invalid final `graphMetaData` offset bytes" + else do + hSeekInt h (length magicHeader) + part1 <- hGetBytes h (offset - length magicHeader) + part2 <- hGetBytes h (flen - offset - 8) + return (Right $ Graph (parseNodes part1) (parseMeta part2)) + +magicHeader :: String +magicHeader = "wtns.graph.001" + +parseNodes :: ByteString -> [Node] +parseNodes = runGet getNodes + +parseMeta :: ByteString -> GraphMetaData +parseMeta = runGet getMetaData + +-------------------------------------------------------------------------------- + +varInt' :: Get Word64 +varInt' = go 0 where + go !cnt = if cnt >= 8 then return 0 else do + w <- getWord8 + if (w < 128) + then return (fromIntegral w) + else do + let x = fromIntegral (w .&. 127) + y <- go (cnt+1) + return (x + 128*y) + +varInt :: Get Int +varInt = fromIntegral <$> varInt' + +varUInt :: Get Word32 +varUInt = fromIntegral <$> varInt' + +-------------------------------------------------------------------------------- + +getNodes :: Get [Node] +getNodes = do + n <- getWord64le + replicateM (fromIntegral n) getNode + +-- | with varint length prefix +getNode :: Get Node +getNode = do + len <- varInt + bs <- getLazyByteString (fromIntegral len) + return (runGet getNode' bs) + +-- | without varint length prefix +getNode' :: Get Node +getNode' = do + nodetype <- getFieldId LEN + case nodetype of + 1 -> AnInputNode <$> getInputNode + 2 -> AConstantNode <$> getConstantNode + 3 -> AnUnoOpNode <$> getUnoOpNode + 4 -> ADuoOpNode <$> getDuoOpNode + 5 -> ATresOpNode <$> getTresOpNode + _ -> error "unexpected node type" + +getInputNode :: Get InputNode +getInputNode = do + SomeNode idx _ _ _ <- getSomeNode + return (InputNode idx) + +getConstantNode :: Get ConstantNode +getConstantNode = do + len <- varInt + bs <- getLazyByteString (fromIntegral len) + return $ ConstantNode (runGet getBigUInt bs) + +getBigUInt :: Get BigUInt +getBigUInt = do + fld <- getFieldId LEN + if fld /= 1 + then error "getBigUInt" + else do + len <- varInt + bs <- getLazyByteString (fromIntegral len) + return $ BigUInt (runGet getByteList bs) + +getByteList :: Get [Word8] +getByteList = do + fld <- getFieldId LEN + if fld /= 1 + then error "getByteList" + else do + len <- varInt + bs <- getLazyByteString (fromIntegral len) + return (L.unpack bs) + +getUnoOpNode :: Get UnoOpNode +getUnoOpNode = do + SomeNode op arg1 _ _ <- getSomeNode + return (UnoOpNode (wordToEnum op) arg1) + +getDuoOpNode :: Get DuoOpNode +getDuoOpNode = do + SomeNode op arg1 arg2 _ <- getSomeNode + return (DuoOpNode (wordToEnum op) arg1 arg2) + +getTresOpNode :: Get TresOpNode +getTresOpNode = do + SomeNode op arg1 arg2 arg3 <- getSomeNode + return (TresOpNode (wordToEnum op) arg1 arg2 arg3) + +wordToEnum :: Enum a => Word32 -> a +wordToEnum w = toEnum (fromIntegral w) + +-------------------------------------------------------------------------------- + +data SomeNode = SomeNode + { field1 :: Word32 + , field2 :: Word32 + , field3 :: Word32 + , field4 :: Word32 + } + deriving Show + +defaultSomeNode :: SomeNode +defaultSomeNode = SomeNode 0 0 0 0 + +insert1 :: (Int,Word32) -> SomeNode -> SomeNode +insert1 (idx,val) old = case idx of + 1 -> old { field1 = val } + 2 -> old { field2 = val } + 3 -> old { field3 = val } + 4 -> old { field4 = val } + +insertMany :: [(Int,Word32)] -> SomeNode -> SomeNode +insertMany list old = foldl' (flip insert1) old list + +getSomeNode :: Get SomeNode +getSomeNode = do + len <- varInt + bs <- getLazyByteString (fromIntegral len) + let list = runGet getRecord bs + return $ insertMany list defaultSomeNode + +-------------------------------------------------------------------------------- +-- TODO: refactor this mess + +getMetaData :: Get GraphMetaData +getMetaData = do + len <- varInt + mapping <- getWitnessMapping + inputs <- getCircuitInputs + return $ GraphMetaData mapping inputs + +getWitnessMapping :: Get WitnessMapping +getWitnessMapping = do + fld <- getFieldId LEN + if fld /= 1 + then error "getWitnessMapping: expecting field 1" + else do + len <- varInt + bs <- getLazyByteString (fromIntegral len) + return $ WitnessMapping (runGet worker bs) + where + worker :: Get [Word32] + worker = isEmpty >>= \b -> if b + then return [] + else (:) <$> varUInt <*> worker + +getCircuitInputs :: Get CircuitInputs +getCircuitInputs = worker where + + worker :: Get [(String, SignalDescription)] + worker = isEmpty >>= \b -> if b + then return [] + else (:) <$> getSingleInput <*> worker + + getSingleInput :: Get (String, SignalDescription) + getSingleInput = do + fld <- getFieldId LEN + if fld /= 2 + then error "getCircuitInputs: expecting field 2" + else do + len <- varInt + bs <- getLazyByteString (fromIntegral len) + return $ runGet inputHelper bs + + inputHelper = do + name <- getName + signal <- getSignal + return (name,signal) + + getName :: Get String + getName = do + fld <- getFieldId LEN + if fld /= 1 + then error "getCircuitInputs/getName: expecting field 1" + else do + len <- varInt + bs <- getLazyByteString (fromIntegral len) + return (LC.unpack bs) + + getSignal :: Get SignalDescription + getSignal = do + fld <- getFieldId LEN + if fld /= 2 + then error "getCircuitInputs/getSignal: expecting field 2" + else do + len <- varInt + bs <- getLazyByteString (fromIntegral len) + return $ runGet signalHelper bs + + signalHelper = do + ofs <- getSignalOffset + len <- getSignalLength + return $ SignalDescription { signalOffset = ofs , signalLength = len } + + getSignalOffset = do + fld <- getFieldId VARINT + if fld /= 1 + then error "getCircuitInputs/getSignalOffset: expecting field 1" + else varUInt + + getSignalLength = do + fld <- getFieldId VARINT + if fld /= 2 + then error "getCircuitInputs/getSignalLength: expecting field 2" + else varUInt + +-------------------------------------------------------------------------------- +-- * protobuf stuff + +-- | There are six wire types: VARINT, I64, LEN, SGROUP, EGROUP, and I32 +data WireType + = VARINT -- ^ used for: int32, int64, uint32, uint64, sint32, sint64, bool, enum + | I64 -- ^ used for: fixed64, sfixed64, double + | LEN -- ^ used for: string, bytes, embedded messages, packed repeated fields + | SGROUP -- ^ used for: group start (deprecated) + | EGROUP -- ^ used for: group end (deprecated) + | I32 -- ^ fixed32, sfixed32, float + deriving (Eq,Show,Enum,Bounded) + +type FieldId = Int + +getFieldId :: WireType -> Get FieldId +getFieldId wty = do + tag <- getWord8 + let (fld,wty') = decodeTag tag + if wty == wty' + then return fld + else error "getFieldId: unexpected protobuf wire type" + +decodeTag_ :: Word8 -> FieldId +decodeTag_ = fst . decodeTag + +decodeTag :: Word8 -> (FieldId, WireType) +decodeTag w = (fld , wty) where + fld = fromIntegral (shiftR w 3) + wty = toEnum (fromIntegral (w .&. 7)) + +-- (index, value) pair +getEntry :: Get (Int,Word32) +getEntry = do + idx <- getFieldId VARINT + val <- varUInt + return (idx,val) + +-- list of (index, value) pairs +getRecord :: Get [(Int,Word32)] +getRecord = sort <$> go where + go = isEmpty >>= \b -> if b + then return [] + else (:) <$> getEntry <*> go + +-------------------------------------------------------------------------------- diff --git a/src/Semantics.hs b/src/Semantics.hs new file mode 100644 index 0000000..ddc1f36 --- /dev/null +++ b/src/Semantics.hs @@ -0,0 +1,101 @@ + +-- | See for the official +-- semantics of the operations + +{-# LANGUAGE StrictData, DeriveFunctor #-} +module Semantics where + +-------------------------------------------------------------------------------- + +import Data.Bits + +import BN254 +import Misc + +-------------------------------------------------------------------------------- + +data PrimOp a + -- unary + = Neg a + | Id a + -- binary + | Mul a a + | Div a a + | Add a a + | Sub a a + | Pow a a + | Idiv a a + | Mod a a + | Eq_ a a + | Neq a a + | Lt a a + | Gt a a + | Leq a a + | Geq a a + | Land a a + | Lor a a + | Shl a a + | Shr a a + | Bor a a + | Band a a + | Bxor a a + -- ternary + | Cond a a a + deriving (Show,Functor) + +-------------------------------------------------------------------------------- + +evalPrimOp :: PrimOp F -> F +evalPrimOp prim = case prim of + Neg x -> neg x + Id x -> x + Mul x y -> mul x y + Div x y -> BN254.div x y + Add x y -> add x y + Sub x y -> sub x y + Pow x y -> pow x (fromF y) + Idiv x y -> if isZero y then 0 else toF $ Prelude.div (fromF x) (fromF y) + Mod x y -> if isZero y then 0 else toF $ Prelude.mod (fromF x) (fromF y) + Eq_ x y -> fromBool (x == y) + Neq x y -> fromBool (x /= y) + Lt x y -> fromBool (signedFromF x < signedFromF y) + Gt x y -> fromBool (signedFromF x > signedFromF y) + Leq x y -> fromBool (signedFromF x <= signedFromF y) + Geq x y -> fromBool (signedFromF x >= signedFromF y) + Land x y -> fromBool (toBool x && toBool y) + Lor x y -> fromBool (toBool x || toBool y) + Shl x y -> shiftLeft x (fromF y) + Shr x y -> shiftRight x (fromF y) + Bor x y -> toF (fromF x .|. fromF y) + Band x y -> toF (fromF x .&. fromF y) + Bxor x y -> toF (fromF x `xor` fromF y) + Cond b x y -> if toBool b then x else y + +-------------------------------------------------------------------------------- + +fieldMask :: Integer +fieldMask = shiftL 1 fieldBits - 1 + +fieldBits :: Int +fieldBits = ceilingLog2 fieldPrime + +shiftRight :: F -> Integer -> F +shiftRight (MkF x) k = if k < halfPrimePlus1 + then if k > fromIntegral fieldBits + then 0 + else MkF (shiftR x (fromInteger k)) + else shiftLeft (MkF x) (fieldPrime - k) + +shiftLeft :: F -> Integer -> F +shiftLeft (MkF x) k = if k < halfPrimePlus1 + then if k > fromIntegral fieldBits + then 0 + else toF (shiftL x (fromIntegral k) .&. fieldMask) + else shiftRight (MkF x) (fieldPrime - k) + +-------------------------------------------------------------------------------- + +{- +notYet :: a +notYet = error "not yet implemented" +-} diff --git a/src/Witness.hs b/src/Witness.hs new file mode 100644 index 0000000..002f1ea --- /dev/null +++ b/src/Witness.hs @@ -0,0 +1,104 @@ + +{-# LANGUAGE StrictData #-} +module Witness where + +-------------------------------------------------------------------------------- + +import Data.Array +import Data.Word + +import qualified Data.Map as Map ; import Data.Map (Map ) +import qualified Data.IntMap as IntMap ; import Data.IntMap (IntMap) + +import BN254 + +import qualified Semantics as S ; import Semantics ( PrimOp , evalPrimOp ) +import qualified Graph as G ; import Graph ( Graph(..) , Node(..) , UnoOpNode(..) , DuoOpNode(..) , TresOpNode(..) , SignalDescription(..) ) + +-------------------------------------------------------------------------------- + +type Witness = Array Int F + +type Inputs = Map String [Integer] + +witnessCalc :: Inputs -> Graph -> Witness +witnessCalc inputs (Graph nodes meta) = witness where + nodesArr = listArray (0,length nodes-1) nodes + rawWitness = evaluateNodes rawInputs nodesArr + rawInputs = convertInputs (G.inputSignals meta) inputs + mapping_ = G.fromWitnessMapping (G.witnessMapping meta) + wtnslen = length mapping_ + mapping = listArray (0,wtnslen-1) mapping_ + witness = listArray (0,wtnslen-1) $ [ rawWitness!(fromIntegral (mapping!i)) | i<-[0..wtnslen-1] ] + +convertInputs :: [(String,SignalDescription)] -> Map String [Integer] -> IntMap F +convertInputs descTable inputTable = IntMap.fromList $ concatMap f descTable where + f :: (String,SignalDescription) -> [(Int,F)] + f (name,desc) = case Map.lookup name inputTable of + Nothing -> error $ "input signal `" ++ name ++ "` not found in the given inputs!" + Just values -> if length values /= fromIntegral (signalLength desc) + then error $ "input signal `" ++ name ++ "` has incorrect size" + else let ofs = fromIntegral (signalOffset desc) :: Int + in (0,1) : zip [ofs..] (map toF values) + +-------------------------------------------------------------------------------- + +type RawInputs = IntMap F + +evaluateNodes :: RawInputs -> Array Int Node -> Array Int F +evaluateNodes inputs nodes = witness where + + (a,b) = bounds nodes + witness = array (0,b-a) $ [ (i-a, worker i) | i<-[a..b] ] + + lkp :: Word32 -> F + lkp i = witness!(fromIntegral i) + + worker i = case nodes!i of + AnInputNode (G.InputNode idx) -> getInputValue idx + AConstantNode (G.ConstantNode big) -> fromBigUInt big + AnUnoOpNode unoOpNode -> evalPrimOp (fmap lkp $ unoToPrimOp unoOpNode ) + ADuoOpNode duoOpNode -> evalPrimOp (fmap lkp $ duoToPrimOp duoOpNode ) + ATresOpNode tresOpNode -> evalPrimOp (fmap lkp $ tresToPrimOp tresOpNode) + + getInputValue j = case IntMap.lookup (fromIntegral j) inputs of + Just y -> y + Nothing -> error ("input value not found at index " ++ show j) + + fromBigUInt (G.BigUInt bs) = fromBytesLE bs + +-------------------------------------------------------------------------------- + +unoToPrimOp :: UnoOpNode -> PrimOp Word32 +unoToPrimOp (UnoOpNode op arg1) = case op of + G.Neg -> S.Neg arg1 + G.Id -> S.Id arg1 + +duoToPrimOp :: DuoOpNode -> PrimOp Word32 +duoToPrimOp (DuoOpNode op arg1 arg2) = case op of + G.Mul -> S.Mul arg1 arg2 + G.Div -> S.Div arg1 arg2 + G.Add -> S.Add arg1 arg2 + G.Sub -> S.Sub arg1 arg2 + G.Pow -> S.Pow arg1 arg2 + G.Idiv -> S.Idiv arg1 arg2 + G.Mod -> S.Mod arg1 arg2 + G.Eq_ -> S.Eq_ arg1 arg2 + G.Neq -> S.Neq arg1 arg2 + G.Lt -> S.Lt arg1 arg2 + G.Gt -> S.Gt arg1 arg2 + G.Leq -> S.Leq arg1 arg2 + G.Geq -> S.Geq arg1 arg2 + G.Land -> S.Land arg1 arg2 + G.Lor -> S.Lor arg1 arg2 + G.Shl -> S.Shl arg1 arg2 + G.Shr -> S.Shr arg1 arg2 + G.Bor -> S.Bor arg1 arg2 + G.Band -> S.Band arg1 arg2 + G.Bxor -> S.Bxor arg1 arg2 + +tresToPrimOp :: TresOpNode -> PrimOp Word32 +tresToPrimOp (TresOpNode op arg1 arg2 arg3) = case op of + G.TernCond -> S.Cond arg1 arg2 arg3 + +--------------------------------------------------------------------------------