From eaac469766e90083ed6edb9473c66337d8ceb68a Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Tue, 5 May 2026 23:41:31 +0200 Subject: [PATCH] more minor things --- src/Leopard/Binding.hs | 26 +++++++++++++++++++++----- src/Leopard/Misc.hs | 4 ++++ src/Leopard/Types.hs | 4 ++-- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/Leopard/Binding.hs b/src/Leopard/Binding.hs index ed10809..af2708f 100644 --- a/src/Leopard/Binding.hs +++ b/src/Leopard/Binding.hs @@ -9,8 +9,10 @@ module Leopard.Binding where import Data.Word import Data.Array import Data.Maybe +import Data.IORef import Control.Monad +import System.IO.Unsafe import Foreign.C import Foreign.C.Types @@ -64,14 +66,28 @@ decodeLeopardResult result = case result of -------------------------------------------------------------------------------- -- * C++ bindings -{-# NOINLINE initLeopard #-} -initLeopard :: IO () -initLeopard = do +{-# NOINLINE initLeopard' #-} +initLeopard' :: IO () +initLeopard' = do res <- cpp_leo_init leo_VERSION if (res == 0) then return () else fail "Leopard initialization failed" +-- I just hope it's not a per-thread initialization... :) +-- (then we would need `Map ThreadId Bool`) +{-# NOINLINE theInitializedFlag #-} +theInitializedFlag :: IORef Bool +theInitializedFlag = unsafePerformIO (newIORef False) + +{-# NOINLINE initLeopard #-} +initLeopard :: IO () +initLeopard = do + ok <- readIORef theInitializedFlag + unless ok $ do + initLeopard' + writeIORef theInitializedFlag True + withLeopard :: IO a -> IO a withLeopard action = do initLeopard @@ -149,14 +165,14 @@ unsafeDecodeIO :: ECParams -> Array Int (Maybe ByteString) -> IO (Either Leopard unsafeDecodeIO ecParams@(ECParams k n) mbChunks = do let m = n - k work_cnt <- cpp_leo_decode_work_count (fromIntegral k) (fromIntegral m) - when (work_cnt == 0) $ fail "edeode: `leo_decode_work_count` claims invalid input" + when (work_cnt == 0) $ fail "decode: `leo_decode_work_count` claims invalid input" let work_cnt_int = fromIntegral work_cnt :: Int let nchunks = arrayLength mbChunks let sizes = map B.length (catMaybes $ elems mbChunks) let mb_chunk_size = isUniformList sizes - unless (n == nchunks) $ fail "encode: we need exactly N encoded chunks" + unless (n == nchunks) $ fail "decode: we need exactly N encoded chunks" unless (isJust mb_chunk_size) $ fail "decode: chunk size must be uniform" let chunk_size = fromJust mb_chunk_size diff --git a/src/Leopard/Misc.hs b/src/Leopard/Misc.hs index e5a6951..ddb4ae5 100644 --- a/src/Leopard/Misc.hs +++ b/src/Leopard/Misc.hs @@ -108,6 +108,10 @@ maskRandomly k arr = do return $ listArray (u,v) [ if b then Just x else Nothing | (x,b) <- zip (elems arr) (elems mask) ] +-- | There will be @k@ @Nothing@-s in the resulting list +maskListRandomly :: Int -> [a] -> IO [Maybe a] +maskListRandomly k xs = elems <$> maskRandomly k (arrayFromList xs) + -- | @randomBoolMask n k@ will give you @k@ falses and @(n-k)@ trues randomBoolMask :: Int -> Int -> IO (Array Int Bool) randomBoolMask n k = go k trues where diff --git a/src/Leopard/Types.hs b/src/Leopard/Types.hs index 9cb0679..f154679 100644 --- a/src/Leopard/Types.hs +++ b/src/Leopard/Types.hs @@ -24,8 +24,8 @@ data ECParams = ECParams deriving (Eq,Show) -- | Number of \"parity\" chunks -ecM :: ECParams -> Int -ecM params = _ecN params - _ecK params +_ecM :: ECParams -> Int +_ecM params = _ecN params - _ecK params isValidECParams :: ECParams -> Bool isValidECParams (ECParams k n) = and