mirror of
https://github.com/logos-storage/rust-bn254-hash.git
synced 2026-01-02 22:03:13 +00:00
105 lines
2.6 KiB
Haskell
105 lines
2.6 KiB
Haskell
|
|
-- | Arithmetic in the BN254 scalar field: the prime-field element type 'F'
-- and a Montgomery-form wrapper 'Mont'.
|
|
|
|
{-# LANGUAGE BangPatterns #-}
|
|
module BN254 where
|
|
|
|
--------------------------------------------------------------------------------
|
|
|
|
import Data.Bits
|
|
import Data.Ratio
|
|
import Text.Printf
|
|
|
|
--------------------------------------------------------------------------------
|
|
|
|
-- | An element of the BN254 scalar field, wrapping an 'Integer'
-- representative.
--
-- Construct values with 'toF' (or numeric literals) so the representative is
-- reduced modulo 'fieldPrime'; the raw 'MkF' constructor does no reduction.
-- NOTE: the derived 'Eq' compares raw representatives, so unreduced values
-- built with 'MkF' may compare unequal to their reduced counterparts.
newtype F = MkF Integer
  deriving (Eq, Show)
|
|
|
|
-- | Unwrap a field element to its stored 'Integer' representative.
fromF :: F -> Integer
fromF (MkF n) = n
|
|
|
|
-- | Embed an arbitrary 'Integer' (possibly negative or larger than the
-- modulus) into the field by reducing it with 'modP'.
toF :: Integer -> F
toF n = MkF (modP n)
|
|
|
|
-- | The prime modulus @p@ of the BN254 scalar field. All arithmetic on 'F'
-- is carried out modulo this value.
fieldPrime :: Integer
fieldPrime =
  21888242871839275222246405745257275088548364400416034343698204186575808495617
|
|
|
|
-- | Reduce an 'Integer' modulo 'fieldPrime'. Because 'mod' takes the sign of
-- the (positive) divisor, the result always lies in @[0, fieldPrime)@, even
-- for negative inputs.
modP :: Integer -> Integer
modP n = n `mod` fieldPrime
|
|
|
|
-- | Ring operations on 'F'. Every result is produced by 'toF', so results are
-- reduced into @[0, fieldPrime)@.
--
-- The field has no ordering, so 'abs' is the identity and 'signum' is
-- constantly one; this still satisfies @abs x * signum x == x@.
instance Num F where
  MkF a + MkF b = toF (a + b)
  MkF a - MkF b = toF (a - b)
  MkF a * MkF b = toF (a * b)
  negate (MkF a) = toF (negate a)
  abs = id
  signum _ = toF 1
  -- 'toF' already takes an 'Integer', so no further conversion is needed.
  fromInteger = toF
|
|
|
|
-- | Square a field element (a single field multiplication).
square :: F -> F
square v = v * v
|
|
|
|
--------------------------------------------------------------------------------
|
|
|
|
-- | Raise a field element to a non-negative 'Integer' power by binary
-- (square-and-multiply) exponentiation, scanning the exponent's bits from
-- least to most significant.
--
-- @power x 0 == 1@ for every @x@, including @0@. A negative exponent is
-- rejected with 'error'.
power :: F -> Integer -> F
power x0 n0
  | n0 < 0    = error "power: expecting non-negative exponent"
  | otherwise = go 1 x0 n0
  where
    -- Loop invariant: acc * s^e == x0^n0.
    go :: F -> F -> Integer -> F
    go !acc _ 0 = acc
    go !acc s e = go acc' (s * s) (shiftR e 1)
      where
        -- Fold the current power of x0 into the result only when the
        -- lowest remaining exponent bit is set.
        acc' = if odd e then acc * s else acc
|
|
|
|
-- | Multiplicative inverse via Fermat's little theorem: for prime @p@ and
-- non-zero @x@, @x^(p-2)@ is the inverse of @x@.
--
-- NOTE: for @x == 0@ this evaluates to @0@ instead of failing.
invNaive :: F -> F
invNaive x = x `power` (fieldPrime - 2)
|
|
|
|
-- | Multiplicative inverse in the field; currently an alias for 'invNaive'.
--
-- NOTE(review): 'invNaive' computes @x^(p-2)@, so @inv 0@ evaluates to @0@
-- rather than failing, which makes division by zero silent.
inv :: F -> F
inv = invNaive
|
|
|
|
--------------------------------------------------------------------------------
|
|
|
|
-- | Field division. 'recip' delegates to 'inv', and @x / y@ is @x * inv y@.
-- A rational @n/d@ is mapped to the field element @n * inv d@; note that a
-- denominator that is zero modulo 'fieldPrime' therefore yields @0@.
instance Fractional F where
  recip = inv
  x / y = x * inv y
  fromRational q = num / den
    where
      num = fromInteger (numerator q)
      den = fromInteger (denominator q)
|
|
|
|
--------------------------------------------------------------------------------
|
|
|
|
-- | A field element kept in Montgomery form: @MkMont y@ stands for the
-- value @y * invMontMultiplier@ (see 'toMont' and 'fromMont').
newtype Mont = MkMont F deriving (Eq, Show)
|
|
|
|
-- | The Montgomery radix @R = 2^256@, reduced into the field.
--
-- The exponent is annotated as 'Int' so that '(^)' does not rely on
-- defaulting its exponent type (which triggers @-Wtype-defaults@).
montMultiplier :: F
montMultiplier = toF (2 ^ (256 :: Int))
|
|
|
|
-- | Field inverse of 'montMultiplier' (i.e. @R^(-1)@). 'fromMont' multiplies
-- by it to strip the Montgomery factor, and Montgomery multiplication
-- multiplies by it to drop the extra factor a product picks up.
invMontMultiplier :: F
invMontMultiplier =
  inv montMultiplier
|
|
|
|
-- | Convert a field element into Montgomery form by scaling it by
-- 'montMultiplier'.
toMont :: F -> Mont
toMont = MkMont . (* montMultiplier)
|
|
|
|
-- | Convert back from Montgomery form by removing the 'montMultiplier'
-- factor (multiplying by its inverse).
fromMont :: Mont -> F
fromMont (MkMont r) = r * invMontMultiplier
|
|
|
|
-- | Ring operations on Montgomery-form values.
--
-- Addition, subtraction and negation act directly on the stored
-- representatives. Multiplying two representatives yields one surplus factor
-- of 'montMultiplier', which is cancelled with 'invMontMultiplier' so that
-- the product stays in Montgomery form.
instance Num Mont where
  MkMont a + MkMont b = MkMont (a + b)
  MkMont a - MkMont b = MkMont (a - b)
  MkMont a * MkMont b = MkMont (a * b * invMontMultiplier)
  negate (MkMont a) = MkMont (negate a)
  abs = id
  -- The Montgomery representative of one is 'montMultiplier' itself.
  signum _ = MkMont montMultiplier
  fromInteger n = toMont (toF (fromInteger n))
|
|
|
|
--------------------------------------------------------------------------------
|
|
|
|
-- | Types that can be rendered as a hexadecimal string.
class ShowHex a where
  -- | Render a value in hexadecimal.
  showHex :: a -> String
|
|
|
|
-- | Print the 'showHex' rendering of a value, followed by a newline.
printHex :: ShowHex a => a -> IO ()
printHex = putStrLn . showHex
|
|
|
|
-- | @0x@-prefixed hexadecimal via 'printf' (@%x@ yields lowercase digits).
instance ShowHex Integer where
  showHex n = printf "0x%x" n

-- | Renders the 'Integer' stored inside the field element.
instance ShowHex F where
  showHex = showHex . fromF

-- | Renders the raw Montgomery-form representative, /not/ the field value it
-- encodes; apply 'fromMont' first to see that value.
instance ShowHex Mont where
  showHex (MkMont r) = showHex r
|
|
|
|
--------------------------------------------------------------------------------
|