-- | Parsing the graph binary format
--
-- File layout, as read by 'readGraphFile':
--
-- > magic header | nodes section | metadata section | 8-byte LE metadata offset
--
-- The nodes section starts with a little-endian 'Word64' node count followed
-- by varint-length-prefixed protobuf-style node messages.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE Strict #-}
module Parser where

--------------------------------------------------------------------------------

import Data.Bits
import Data.Word
import Data.List
import Data.Ord

import Control.Monad
import Control.Applicative
import System.IO

import Data.ByteString.Lazy ( ByteString )
import qualified Data.ByteString.Lazy       as L
import qualified Data.ByteString.Lazy.Char8 as LC

import "binary" Data.Binary.Get
import "binary" Data.Binary.Get.Internal ( lookAhead )
import "binary" Data.Binary.Builder as Builder

import Graph

--------------------------------------------------------------------------------

{-
test :: IO ()
test = do
  Right graph <- parseGraphFile "../tmp/graph.bin"
  print graph
-}

{-
nodeEx1 = 0x06 : nodeEx1' :: [Word8]
nodeEx2 = 0x08 : nodeEx2' :: [Word8]

nodeEx1' = 0x22 : 0x04 : duoNodeEx1 :: [Word8]
nodeEx2' = 0x22 : 0x06 : duoNodeEx2 :: [Word8]

duoNodeEx1 = [ 0x10 , 0x05 , 0x18 , 0x05 ] :: [Word8]
duoNodeEx2 = [ 0x08 , 0x02 , 0x10 , 0x03 , 0x18 , 0x04 ] :: [Word8]
-}

--------------------------------------------------------------------------------

-- | Human-readable error message.
type Msg = String

-- | Open a graph file and parse it.
--
-- The handle is closed even when reading throws, because the file is opened
-- with 'withBinaryFile'. Structural problems detected up front are reported
-- as 'Left'; malformed node / metadata payloads still surface as exceptions
-- when the (lazily parsed) result is forced.
parseGraphFile :: FilePath -> IO (Either Msg Graph)
parseGraphFile fname = withBinaryFile fname ReadMode readGraphFile

-- | Read exactly @n@ bytes (or fewer at end of file) from the handle.
hGetBytes :: Handle -> Int -> IO ByteString
hGetBytes h n = L.hGet h (fromIntegral n)

-- | Seek to an absolute byte offset.
hSeekInt :: Handle -> Int -> IO ()
hSeekInt h ofs = hSeek h AbsoluteSeek (fromIntegral ofs)

-- | Read a whole graph from an already opened binary handle.
--
-- Returns 'Left' when the magic header is wrong, when the file is too short
-- to hold the trailing 8-byte offset, or when that offset points outside the
-- region between the header and the trailer.
readGraphFile :: Handle -> IO (Either Msg Graph)
readGraphFile h = do
  flen  <- (fromIntegral :: Integer -> Int) <$> hFileSize h
  magic <- hGetBytes h magicLen
  if magic /= LC.pack magicHeader
    then return $ Left "magic header not found or invalid"
    else if flen < magicLen + 8
      -- without this guard the seek below would get a negative offset
      then return $ Left "file too short to contain the final `graphMetaData` offset bytes"
      else do
        hSeekInt h (flen - 8)
        offset <- (fromIntegral . runGet getWord64le) <$> hGetBytes h 8
        -- putStrLn $ "metadata offset = " ++ show offset
        -- The metadata must end where the 8-byte trailer begins, so any
        -- offset past @flen - 8@ would yield a negative read size below.
        -- The lower bound (18) is kept as it was.
        if (offset > flen - 8) || (offset <= 18)
          then return $ Left "invalid final `graphMetaData` offset bytes"
          else do
            hSeekInt h magicLen
            part1 <- hGetBytes h (offset - magicLen)
            part2 <- hGetBytes h (flen - offset - 8)
            return (Right $ Graph (parseNodes part1) (parseMeta part2))
  where
    magicLen = length magicHeader

-- | The ASCII signature every graph file starts with.
magicHeader :: String
magicHeader = "wtns.graph.001"

-- | Parse the nodes section (count-prefixed list of nodes).
parseNodes :: ByteString -> [Node]
parseNodes = runGet getNodes

-- | Parse the metadata section.
parseMeta :: ByteString -> GraphMetaData
parseMeta = runGet getMetaData

--------------------------------------------------------------------------------

-- | Decode a base-128 varint (little-endian groups of 7 bits, high bit set on
-- every byte but the last).
--
-- A 64-bit value needs at most 10 bytes; longer encodings make the parser
-- fail instead of silently returning a truncated value and leaving the
-- remaining bytes in the stream.
varInt' :: Get Word64
varInt' = go (0 :: Int)
  where
    go !cnt
      | cnt >= 10 = fail "varInt': varint is longer than 10 bytes"
      | otherwise = do
          w <- getWord8
          if w < 128
            then return (fromIntegral w)
            else do
              let x = fromIntegral (w .&. 127)
              y <- go (cnt + 1)
              return (x + 128 * y)

-- | A varint narrowed to 'Int'.
varInt :: Get Int
varInt = fromIntegral <$> varInt'

-- | A varint narrowed to 'Word32' (higher bits are dropped).
varUInt :: Get Word32
varUInt = fromIntegral <$> varInt'

--------------------------------------------------------------------------------

-- | Abort with a message saying which field id was expected and which one was
-- actually found. Arguments: actual id, name of the parser, expected id.
expectingError :: Int -> String -> Int -> Get a
expectingError actual what shouldbe =
  error $ what ++ ": expecting field " ++ show shouldbe ++ "; got " ++ show actual ++ " instead"

-- | Read a varint length and then that many bytes.
getLenPrefixed :: Get ByteString
getLenPrefixed = do
  len <- varInt
  getLazyByteString (fromIntegral len)

-- | Read a LEN-typed field: check that its field id is the expected one
-- (calling 'expectingError' with the given parser name otherwise) and return
-- its length-prefixed payload.
getLenField :: String -> FieldId -> Get ByteString
getLenField what expectedId = do
  fld <- getFieldId LEN
  if fld /= expectedId
    then expectingError fld what expectedId
    else getLenPrefixed

-- | A little-endian 'Word64' count followed by that many nodes.
getNodes :: Get [Node]
getNodes = do
  n <- getWord64le
  replicateM (fromIntegral n) getNode

-- | with varint length prefix
getNode :: Get Node
getNode = do
  bs <- getLenPrefixed
  return (runGet getNode' bs)

-- | without varint length prefix
getNode' :: Get Node
getNode' = do
  nodetype <- getFieldId LEN
  case nodetype of
    1 -> AnInputNode   <$> getInputNode
    2 -> AConstantNode <$> getConstantNode
    3 -> AnUnoOpNode   <$> getUnoOpNode
    4 -> ADuoOpNode    <$> getDuoOpNode
    5 -> ATresOpNode   <$> getTresOpNode
    _ -> error "unexpected node type"

-- | An input node: only the first record field is used.
getInputNode :: Get InputNode
getInputNode = do
  SomeNode idx _ _ _ <- getSomeNode
  return (InputNode idx)

-- | A constant node: a length-prefixed message holding one big unsigned
-- integer.
getConstantNode :: Get ConstantNode
getConstantNode = do
  bs <- getLenPrefixed
  return $ ConstantNode (runGet getBigUInt bs)

-- | A big unsigned integer stored in the LEN field with the given id; its
-- payload is parsed with 'getByteList'.
getBigUInt' :: FieldId -> Get BigUInt
getBigUInt' expectedId = do
  bs <- getLenField "getBigUInt" expectedId
  return $ BigUInt (runGet getByteList bs)

-- | 'getBigUInt'' for field id 1.
getBigUInt :: Get BigUInt
getBigUInt = getBigUInt' 1

-- | The raw bytes of LEN field 1.
getByteList :: Get [Word8]
getByteList = L.unpack <$> getLenField "getByteList" 1

-- | A string stored in the LEN field with the given id; bytes are mapped to
-- characters one-to-one.
getString' :: FieldId -> Get String
getString' expectedId = LC.unpack <$> getLenField "getString" expectedId

-- | Unary operation node: operation code and one argument.
getUnoOpNode :: Get UnoOpNode
getUnoOpNode = do
  SomeNode op arg1 _ _ <- getSomeNode
  return (UnoOpNode (wordToEnum op) arg1)

-- | Binary operation node: operation code and two arguments.
getDuoOpNode :: Get DuoOpNode
getDuoOpNode = do
  SomeNode op arg1 arg2 _ <- getSomeNode
  return (DuoOpNode (wordToEnum op) arg1 arg2)

-- | Ternary operation node: operation code and three arguments.
getTresOpNode :: Get TresOpNode
getTresOpNode = do
  SomeNode op arg1 arg2 arg3 <- getSomeNode
  return (TresOpNode (wordToEnum op) arg1 arg2 arg3)

-- | Convert a numeric code to an enumeration value via 'toEnum'.
wordToEnum :: Enum a => Word32 -> a
wordToEnum w = toEnum (fromIntegral w)

--------------------------------------------------------------------------------

-- | Generic carrier for a node message with up to four varint fields;
-- fields that are absent from the message stay 0 (see 'defaultSomeNode').
data SomeNode = SomeNode
  { field1 :: Word32
  , field2 :: Word32
  , field3 :: Word32
  , field4 :: Word32
  }
  deriving Show

-- | All four fields set to 0.
defaultSomeNode :: SomeNode
defaultSomeNode = SomeNode 0 0 0 0

-- | Store one @(field index, value)@ pair in the record.
--
-- Indices outside 1..4 raise an error naming the offending index instead of
-- an anonymous pattern-match failure.
insert1 :: (Int,Word32) -> SomeNode -> SomeNode
insert1 (idx,val) old = case idx of
  1 -> old { field1 = val }
  2 -> old { field2 = val }
  3 -> old { field3 = val }
  4 -> old { field4 = val }
  _ -> error $ "insert1: unexpected field index " ++ show idx ++ " (expected 1..4)"

-- | Store a list of pairs, left to right (later entries overwrite earlier
-- ones with the same index).
insertMany :: [(Int,Word32)] -> SomeNode -> SomeNode
insertMany list old = foldl' (flip insert1) old list

-- | A length-prefixed message made of varint fields, collected into a
-- 'SomeNode'.
getSomeNode :: Get SomeNode
getSomeNode = do
  bs <- getLenPrefixed
  let list = runGet getRecord bs
  return $ insertMany list defaultSomeNode

--------------------------------------------------------------------------------

-- TODO: refactor this mess

-- | The metadata section: a leading varint, then the witness mapping, the
-- circuit inputs and the prime.
getMetaData :: Get GraphMetaData
getMetaData = do
  -- presumably the length of the whole metadata message; it is read to
  -- advance the stream but not checked against the actual payload size
  _len    <- varInt
  mapping <- getWitnessMapping
  inputs  <- getCircuitInputs
  prime   <- getPrime
  return $ GraphMetaData mapping inputs prime

-- | The prime: number in field 3, name in field 4.
getPrime :: Get Prime
getPrime = do
  number <- getBigUInt' 3
  name   <- getString'  4
  return $ Prime
    { primeNumber = number
    , primeName   = name
    }

-- | Field 1 of the metadata: a packed sequence of varints.
getWitnessMapping :: Get WitnessMapping
getWitnessMapping = do
  bs <- getLenField "getWitnessMapping" 1
  return $ WitnessMapping (runGet worker bs)
  where
    -- read varints until the payload is exhausted
    worker :: Get [Word32]
    worker = isEmpty >>= \b -> if b
      then return []
      else (:) <$> varUInt <*> worker

-- | Zero or more occurrences of field 2, each a @(name, signal)@ entry.
-- Parsing stops, without consuming anything, at the first tag whose field id
-- is not 2.
getCircuitInputs :: Get CircuitInputs
getCircuitInputs = worker
  where

    {-
    worker :: Get [(String, SignalDescription)]
    worker = isEmpty >>= \b -> if b
      then return []
      else (:) <$> getSingleInput <*> worker
    -}

    worker :: Get [(String, SignalDescription)]
    worker = do
      mb <- getSingleInput
      case mb of
        Nothing   -> return []
        Just this -> (this:) <$> worker

    -- Peek at the next tag; consume the entry only when it is field 2.
    getSingleInput :: Get (Maybe (String, SignalDescription))
    getSingleInput = do
      fld <- lookAhead (getFieldId LEN)
      if fld /= 2
        then return Nothing
        -- expectingError fld "getSingleInput" 2
        else do
          _fld <- getFieldId LEN
          bs   <- getLenPrefixed
          return $ Just $ runGet inputHelper bs

    inputHelper :: Get (String, SignalDescription)
    inputHelper = do
      name   <- getName
      signal <- getSignal
      return (name,signal)

    getName :: Get String
    getName = LC.unpack <$> getLenField "getCircuitInputs/getName" 1

    getSignal :: Get SignalDescription
    getSignal = runGet signalHelper <$> getLenField "getCircuitInputs/getSignal" 2

    signalHelper :: Get SignalDescription
    signalHelper = do
      ofs <- getSignalOffset
      len <- getSignalLength
      return $ SignalDescription { signalOffset = ofs , signalLength = len }

    getSignalOffset :: Get Word32
    getSignalOffset = do
      fld <- getFieldId VARINT
      if fld /= 1
        then expectingError fld "getCircuitInputs/getSignalOffset" 1
        else varUInt

    getSignalLength :: Get Word32
    getSignalLength = do
      fld <- getFieldId VARINT
      if fld /= 2
        then expectingError fld "getCircuitInputs/getSignalLength" 2
        else varUInt

--------------------------------------------------------------------------------
-- * protobuf stuff

-- | There are six wire types: VARINT, I64, LEN, SGROUP, EGROUP, and I32
data WireType
  = VARINT     -- ^ used for: int32, int64, uint32, uint64, sint32, sint64, bool, enum
  | I64        -- ^ used for: fixed64, sfixed64, double
  | LEN        -- ^ used for: string, bytes, embedded messages, packed repeated fields
  | SGROUP     -- ^ used for: group start (deprecated)
  | EGROUP     -- ^ used for: group end (deprecated)
  | I32        -- ^ fixed32, sfixed32, float
  deriving (Eq,Show,Enum,Bounded)

-- | Protobuf field number.
type FieldId = Int

-- | Read a one-byte tag, check that its wire type is the expected one and
-- return the field id.
--
-- NOTE(review): only single-byte tags are read here, so field numbers that
-- need a multi-byte tag are not supported.
getFieldId :: WireType -> Get FieldId
getFieldId wty = do
  tag <- getWord8
  let (fld,wty') = decodeTag tag
  if wty == wty'
    then return fld
    else error "getFieldId: unexpected protobuf wire type"

-- | Field id of a tag byte, ignoring the wire type.
decodeTag_ :: Word8 -> FieldId
decodeTag_ = fst . decodeTag

-- | Split a tag byte into field id (upper 5 bits) and wire type (lower 3
-- bits). Wire-type codes 6 and 7 do not exist and raise an error naming the
-- code.
decodeTag :: Word8 -> (FieldId, WireType)
decodeTag w = (fld , wty)
  where
    fld  = fromIntegral (shiftR w 3)
    code = fromIntegral (w .&. 7) :: Int
    wty
      | code <= fromEnum (maxBound :: WireType) = toEnum code
      | otherwise = error $ "decodeTag: invalid protobuf wire type " ++ show code

-- (index, value) pair
getEntry :: Get (Int,Word32)
getEntry = do
  idx <- getFieldId VARINT
  val <- varUInt
  return (idx,val)

-- list of (index, value) pairs, ordered by index
--
-- The sort is stable and keyed on the index only, so when a field occurs
-- more than once the occurrences keep their stream order and the last one
-- wins in 'insertMany' (sorting whole tuples would let the largest value win).
getRecord :: Get [(Int,Word32)]
getRecord = sortBy (comparing fst) <$> go
  where
    go = isEmpty >>= \b -> if b
      then return []
      else (:) <$> getEntry <*> go

--------------------------------------------------------------------------------