more minor things

This commit is contained in:
Balazs Komuves 2026-05-05 23:41:31 +02:00
parent 3d36140218
commit eaac469766
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
3 changed files with 27 additions and 7 deletions

View File

@ -9,8 +9,10 @@ module Leopard.Binding where
import Data.Word import Data.Word
import Data.Array import Data.Array
import Data.Maybe import Data.Maybe
import Data.IORef
import Control.Monad import Control.Monad
import System.IO.Unsafe
import Foreign.C import Foreign.C
import Foreign.C.Types import Foreign.C.Types
@ -64,14 +66,28 @@ decodeLeopardResult result = case result of
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- * C++ bindings -- * C++ bindings
{-# NOINLINE initLeopard #-} {-# NOINLINE initLeopard' #-}
initLeopard :: IO () initLeopard' :: IO ()
initLeopard = do initLeopard' = do
res <- cpp_leo_init leo_VERSION res <- cpp_leo_init leo_VERSION
if (res == 0) if (res == 0)
then return () then return ()
else fail "Leopard initialization failed" else fail "Leopard initialization failed"
-- I just hope it's not a per-thread initialization... :)
-- (then we would need `Map ThreadId Bool`)
{-# NOINLINE theInitializedFlag #-}
theInitializedFlag :: IORef Bool
theInitializedFlag = unsafePerformIO (newIORef False)
{-# NOINLINE initLeopard #-}
initLeopard :: IO ()
initLeopard = do
ok <- readIORef theInitializedFlag
unless ok $ do
initLeopard'
writeIORef theInitializedFlag True
withLeopard :: IO a -> IO a withLeopard :: IO a -> IO a
withLeopard action = do withLeopard action = do
initLeopard initLeopard
@ -149,14 +165,14 @@ unsafeDecodeIO :: ECParams -> Array Int (Maybe ByteString) -> IO (Either Leopard
unsafeDecodeIO ecParams@(ECParams k n) mbChunks = do unsafeDecodeIO ecParams@(ECParams k n) mbChunks = do
let m = n - k let m = n - k
work_cnt <- cpp_leo_decode_work_count (fromIntegral k) (fromIntegral m) work_cnt <- cpp_leo_decode_work_count (fromIntegral k) (fromIntegral m)
when (work_cnt == 0) $ fail "edeode: `leo_decode_work_count` claims invalid input" when (work_cnt == 0) $ fail "decode: `leo_decode_work_count` claims invalid input"
let work_cnt_int = fromIntegral work_cnt :: Int let work_cnt_int = fromIntegral work_cnt :: Int
let nchunks = arrayLength mbChunks let nchunks = arrayLength mbChunks
let sizes = map B.length (catMaybes $ elems mbChunks) let sizes = map B.length (catMaybes $ elems mbChunks)
let mb_chunk_size = isUniformList sizes let mb_chunk_size = isUniformList sizes
unless (n == nchunks) $ fail "encode: we need exactly N encoded chunks" unless (n == nchunks) $ fail "decode: we need exactly N encoded chunks"
unless (isJust mb_chunk_size) $ fail "decode: chunk size must be uniform" unless (isJust mb_chunk_size) $ fail "decode: chunk size must be uniform"
let chunk_size = fromJust mb_chunk_size let chunk_size = fromJust mb_chunk_size

View File

@ -108,6 +108,10 @@ maskRandomly k arr = do
return $ listArray (u,v) return $ listArray (u,v)
[ if b then Just x else Nothing | (x,b) <- zip (elems arr) (elems mask) ] [ if b then Just x else Nothing | (x,b) <- zip (elems arr) (elems mask) ]
-- | There will be @k@ @Nothing@-s in the resulting list
maskListRandomly :: Int -> [a] -> IO [Maybe a]
maskListRandomly k xs = elems <$> maskRandomly k (arrayFromList xs)
-- | @randomBoolMask n k@ will give you @k@ falses and @(n-k)@ trues -- | @randomBoolMask n k@ will give you @k@ falses and @(n-k)@ trues
randomBoolMask :: Int -> Int -> IO (Array Int Bool) randomBoolMask :: Int -> Int -> IO (Array Int Bool)
randomBoolMask n k = go k trues where randomBoolMask n k = go k trues where

View File

@ -24,8 +24,8 @@ data ECParams = ECParams
deriving (Eq,Show) deriving (Eq,Show)
-- | Number of \"parity\" chunks -- | Number of \"parity\" chunks
ecM :: ECParams -> Int _ecM :: ECParams -> Int
ecM params = _ecN params - _ecK params _ecM params = _ecN params - _ecK params
isValidECParams :: ECParams -> Bool isValidECParams :: ECParams -> Bool
isValidECParams (ECParams k n) = and isValidECParams (ECParams k n) = and