300 lines
9.2 KiB
Haskell

{-# LANGUAGE ScopedTypeVariables #-}
module Misc where
--------------------------------------------------------------------------------
import Data.Bits
import Data.Array
import Data.List
import qualified Data.Set as Set ; import Data.Set (Set)
import Control.Monad
import Data.Binary
import Debug.Trace
--------------------------------------------------------------------------------
-- * Debug
debug_ :: Show a => a -> b -> b
debug_ x y = trace (">>> " ++ show x) y
debug :: Show a => String -> a -> b -> b
debug n x y = trace (">>> " ++ n ++ " = " ++ show x) y
debugPrint :: Show a => String -> a -> IO ()
debugPrint n x = putStrLn ("> " ++ n ++ " = " ++ show x)
--------------------------------------------------------------------------------
-- * Integers
isEven :: Integer -> Bool
isEven n = (n .&. 1) == 0
isOdd :: Integer -> Bool
isOdd n = (n .&. 1) /= 0
--------------------------------------------------------------------------------
-- * Strings
-- | The difference from 'unlines' is that this one doesn't add a final newline
unlines1 :: [String] -> String
unlines1 = intercalate "\n"
--------------------------------------------------------------------------------
-- * Log2
newtype Log2
= Log2 Int
deriving (Eq,Ord,Show,Num)
fromLog2 :: Log2 -> Int
fromLog2 (Log2 k)
| k >=0 = k
| otherwise = error "fromLog2: negative exponent"
exp2 :: Log2 -> Integer
exp2 (Log2 k)
| k >= 0 = shiftL 1 k
| otherwise = error "exp2: negative exponent"
exp2_ :: Log2 -> Int
exp2_ (Log2 k)
| k >= 0 = shiftL 1 k
| otherwise = error "exp2_: negative exponent"
-- | Smallest integer @k@ such that @2^k@ is larger or equal to @n@
ceilingLog2 :: Integer -> Log2
ceilingLog2 = Log2 . wrapper where
wrapper 0 = 0
wrapper n = 1 + go (n-1) where
go 0 = -1
go k = 1 + go (shiftR k 1)
exactLog2 :: Integer -> Maybe Log2
exactLog2 n = if re == n then Just log2 else Nothing where
log2 = ceilingLog2 n
re = exp2 log2
exactLog2_ :: Integer -> Log2
exactLog2_ n = case exactLog2 n of
Just log2 -> log2
Nothing -> error "exactLog2_: not a power of two"
exactLog2__ :: Int -> Log2
exactLog2__ = exactLog2_ . fromIntegral
--------------------------------------------------------------------------------
-- * Lists
-- | just to avoid the annoying GHC warning
myHead :: [a] -> a
myHead (x:_) = x
myHead [] = error "myHead: empty list"
safeZipWith :: (a -> b -> c) -> [a] -> [b] -> [c]
safeZipWith f = go where
go [] [] = []
go (x:xs) (y:ys) = f x y : go xs ys
go _ _ = error "safeZipWith: incompatible lengths"
safeZipWith3 :: (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
safeZipWith3 f = go where
go [] [] [] = []
go (x:xs) (y:ys) (z:zs) = f x y z : go xs ys zs
go _ _ _ = error "safeZipWith3: incompatible lengths"
safeZipWith4 :: (a -> b -> c -> d -> e) -> [a] -> [b] -> [c] -> [d] -> [e]
safeZipWith4 f = go where
go [] [] [] [] = []
go (x:xs) (y:ys) (z:zs) (w:ws) = f x y z w : go xs ys zs ws
go _ _ _ _ = error "safeZipWith4: incompatible lengths"
safeZip = safeZipWith (,)
safeZip3 = safeZipWith3 (,,)
safeZip4 = safeZipWith4 (,,,)
interleave :: [a] -> [a] -> [a]
interleave (x:xs) (y:ys) = x:y:interleave xs ys
interleave [] [] = []
interleave _ _ = error "interleave: expecting input lists of the same length"
partitionIntoChunks :: Int -> [a] -> [[a]]
partitionIntoChunks k = go where
go [] = []
go xs = take k xs : go (drop k xs)
nubOrd :: Ord a => [a] -> [a]
nubOrd = worker Set.empty where
worker _ [] = []
worker s (x:xs)
| Set.member x s = worker s xs
| otherwise = x : worker (Set.insert x s) xs
--------------------------------------------------------------------------------
-- * Monads
-- WTF, this is not in @base@ but 'Maybe' is?!
instance MonadFail (Either String) where
fail = Left
safeZipWithM :: forall m a b c. MonadFail m => (a -> b -> m c) -> [a] -> [b] -> m [c]
safeZipWithM f xs ys = go xs ys where
go :: [a] -> [b] -> m [c]
go (x:xs) (y:ys) = do
z <- f x y
zs <- go xs ys
return (z:zs)
go [] [] = return []
go _ _ = fail "safeZipWithM: incompatible input sizes"
safeZipWith3M :: MonadFail m => (a -> b -> c -> m d) -> [a] -> [b] -> [c] -> m [d]
safeZipWith3M f xs ys zs = go xs ys zs where
go (x:xs) (y:ys) (z:zs) = do
w <- f x y z
ws <- go xs ys zs
return (w:ws)
go [] [] [] = return []
go _ _ _ = fail "safeZipWith3M: incompatible input sizes"
safeZipWith4M :: MonadFail m => (a -> b -> c -> d -> m e) -> [a] -> [b] -> [c] -> [d] -> m [e]
safeZipWith4M f xs ys zs us = go xs ys zs us where
go (x:xs) (y:ys) (z:zs) (u:us) = do
w <- f x y z u
ws <- go xs ys zs us
return (w:ws)
go [] [] [] [] = return []
go _ _ _ _ = fail "safeZipWith4M: incompatible input sizes"
safeZipWithM_ :: MonadFail m => (a -> b -> m c) -> [a] -> [b] -> m ()
safeZipWithM_ f xs ys = void (safeZipWithM f xs ys)
safeZipWith3M_ :: MonadFail m => (a -> b -> c -> m d) -> [a] -> [b] -> [c] -> m ()
safeZipWith3M_ f xs ys zs = void (safeZipWith3M f xs ys zs)
safeZipWith4M_ :: MonadFail m => (a -> b -> c -> d -> m e) -> [a] -> [b] -> [c] -> [d] -> m ()
safeZipWith4M_ f xs ys zs us = void (safeZipWith4M f xs ys zs us)
safeFlippedZipWithM_ :: MonadFail m => [a] -> [b] -> (a -> b -> m c) -> m ()
safeFlippedZipWithM_ xs ys f = safeZipWithM_ f xs ys
safeFlippedZipWith3M_ :: MonadFail m => [a] -> [b] -> [c] -> (a -> b -> c -> m d) -> m ()
safeFlippedZipWith3M_ xs ys zs f = safeZipWith3M_ f xs ys zs
safeFlippedZipWith4M_ :: MonadFail m =>[a] -> [b] -> [c] -> [d] -> (a -> b -> c -> d -> m e) -> m ()
safeFlippedZipWith4M_ xs ys zs us f = safeZipWith4M_ f xs ys zs us
--------------------------------------------------------------------------------
-- * Arrays
singletonArray :: a -> Array Int a
singletonArray x = listArray (0,0) [x]
listToArray :: [a] -> Array Int a
listToArray xs = listArray (0, length xs - 1) xs
makeArray :: Int -> (Int -> a) -> Array Int a
makeArray n fun = listArray (0,n-1) [ fun i | i<-[0..n-1] ]
arrayLength :: Array Int a -> Int
arrayLength arr = b - a + 1 where (a,b) = bounds arr
-- | Synonym for 'arrayLength'
arraySize :: Array Int a -> Int
arraySize = arrayLength
-- | Returns the default value when out of range
safeIndex :: a -> Array Int a -> Int -> a
safeIndex def arr j
| j < a = def
| j > b = def
| otherwise = arr!j
where
(a,b) = bounds arr
interleaveArrays' :: Array Int (Array Int a) -> Array Int a
interleaveArrays' arrs
| nubOrd (elems sizes) == [n] = big
| otherwise = error "interleaveArrays': incompatible array sizes"
where
m = arraySize arrs
sizes = fmap arrayLength arrs
n = sizes!0
big = listArray (0,n*m-1) [ (arrs!j)!i | i<-[0..n-1] , j<-[0..m-1] ]
interleaveArrays :: [Array Int a] -> Array Int a
interleaveArrays arrayList = interleaveArrays' (listToArray arrayList)
-- | This is the inverse of @interleaveArrays@. The integer parameter is the number
-- of output vectors (or \"stride\")
untangleArray :: Int -> Array Int a -> [Array Int a]
untangleArray stride input
| r /= 0 = error "untangleArrays: input array's size is not divisible by the stride"
| otherwise = pieces
where
n = arraySize input
(q,r) = divMod n stride
pieces = [ extractCosetArray j stride input | j<-[0..stride-1] ]
{-
pieces =
[ listArray (0,q-1) [ input ! (j + i*stride) | i <- [0..q-1] ]
| j <- [0..stride-1]
]
-}
untangleArray' :: Int -> Array Int a -> Array Int (Array Int a)
untangleArray' stride = listToArray . untangleArray stride
-- | This extracts a subarray with indices of the form @[ offset + i*stride | i<-[0..n-1] ]@
extractCosetArray :: Int -> Int -> Array Int a -> Array Int a
extractCosetArray offset stride input
| r /= 0 = error "extractCosetArray: input array's size is not divisible by the stride"
| otherwise = piece
where
n = arraySize input
(q,r) = divMod n stride
piece = listArray (0,q-1) [ input ! (offset + i*stride) | i <- [0..q-1] ]
-- | This extracts a subarray with indices of the form @[ i*stride | i<-[0..n-1] ]@
extractSubgroupArray :: Int -> Array Int a -> Array Int a
extractSubgroupArray stride = extractCosetArray 0 stride
--------------------------------------------------------------------------------
-- * Binary encoding
instance Binary Log2 where
put (Log2 k) = putWord8 (fromIntegral k)
get = (Log2 . fromIntegral) <$> getWord8
putSmallList :: Binary a => [a] -> Put
putSmallList list = do
let n = length list
if (n < 256)
then do
putWord8 (fromIntegral n)
mapM_ put list
else error "putSmallList: array length >= 256"
getSmallList :: Binary a => Get [a]
getSmallList = do
len <- fromIntegral <$> getWord8 :: Get Int
replicateM len get
putSmallArray :: Binary a => Array Int a -> Put
putSmallArray list = do
let n = arrayLength list
if (n < 256)
then do
putWord8 (fromIntegral n)
mapM_ put list
else error "putSmallArray: array length >= 256"
getSmallArray :: Binary a => Get (Array Int a)
getSmallArray = do
len <- fromIntegral <$> getWord8 :: Get Int
listToArray <$> replicateM len get
--------------------------------------------------------------------------------