105 lines
2.6 KiB
Haskell

-- | Arithmetic in the BN254 scalar field, plus a Montgomery-form wrapper type.
{-# LANGUAGE BangPatterns #-}
module BN254 where
--------------------------------------------------------------------------------
import Data.Bits
import Data.Ratio
import Text.Printf
--------------------------------------------------------------------------------
-- | An element of the BN254 scalar field, stored as an 'Integer'.
--
-- Values produced by 'toF' (and hence by the 'Num' operations) are reduced
-- into @[0, fieldPrime)@.  The raw 'MkF' constructor performs no reduction,
-- and the derived 'Eq' compares the stored integers directly.
newtype F = MkF Integer
  deriving (Eq, Show)

-- | Unwrap a field element to its stored 'Integer'.
fromF :: F -> Integer
fromF f = case f of
  MkF n -> n
-- | Embed an 'Integer' into the field, reducing it into @[0, fieldPrime)@.
-- Negative inputs are handled too, since 'mod' follows the divisor's sign.
toF :: Integer -> F
toF n = MkF (modP n)
-- | Modulus of the BN254 scalar field.
fieldPrime :: Integer
fieldPrime =
  21888242871839275222246405745257275088548364400416034343698204186575808495617

-- | Reduce an 'Integer' to its representative in @[0, fieldPrime)@.
modP :: Integer -> Integer
modP x = x `mod` fieldPrime
-- | Ring operations on 'F'.  Every result is passed through 'toF', so it
-- comes back reduced into @[0, fieldPrime)@.
--
-- A field has no notion of sign, so 'abs' is the identity and 'signum' is
-- constantly one.  This keeps @abs x * signum x == x@.
instance Num F where
  MkF a + MkF b  = toF (a + b)
  MkF a - MkF b  = toF (a - b)
  MkF a * MkF b  = toF (a * b)
  negate (MkF a) = toF (negate a)
  abs            = id
  signum _       = toF 1
  fromInteger n  = toF n
-- | Square a field element.
square :: F -> F
square x = x ^ (2 :: Int)
--------------------------------------------------------------------------------
-- | @power x e@ computes @x^e@ in the field by right-to-left binary
-- (square-and-multiply) exponentiation, consuming the exponent from its
-- least-significant bit upward.  @power x 0 == 1@ for every @x@, including 0.
--
-- Calls 'error' when the exponent is negative.
power :: F -> Integer -> F
power x0 e0
  | e0 < 0    = error "power: expecting non-negative exponent"
  | otherwise = go 1 x0 e0
  where
    -- Invariant: acc * s^e == x0^e0.  Each step folds the lowest bit of e
    -- into acc (if set), squares s and shifts e right by one.
    go !acc _ 0 = acc
    go !acc s e = go acc' (s * s) (shiftR e 1)
      where
        acc' = if testBit e 0 then acc * s else acc
-- | Multiplicative inverse computed as @x^(fieldPrime - 2)@ (Fermat's little
-- theorem).  For @x == 0@ this returns 0 rather than failing.
invNaive :: F -> F
invNaive = (`power` (fieldPrime - 2))
-- | Field inversion used throughout this module ('recip', '/', and the
-- Montgomery constants).  It currently delegates to 'invNaive', so
-- @inv 0 == 0@.
inv :: F -> F
inv = invNaive
--------------------------------------------------------------------------------
-- | Division in the field, built on 'inv'.  Because 'inv' maps 0 to 0,
-- dividing by zero yields zero instead of raising an error.
instance Fractional F where
  fromRational q = fromInteger (numerator q) * inv (fromInteger (denominator q))
  recip          = inv
  x / y          = x * recip y
--------------------------------------------------------------------------------
-- | A field element in Montgomery form: @MkMont y@ stores @y = x * R@, where
-- @R = 'montMultiplier'@ and @x@ is the represented value.  Use 'toMont' and
-- 'fromMont' to convert; the raw constructor does no conversion.
newtype Mont = MkMont F deriving (Eq, Show)
-- | The Montgomery radix @R = 2^256@, reduced into the field.
montMultiplier :: F
montMultiplier = toF (2 ^ (256 :: Int))
-- | @R^(-1)@: the field inverse of 'montMultiplier'.
invMontMultiplier :: F
invMontMultiplier = recip montMultiplier
-- | Convert a field element into Montgomery form (@x ↦ x * R@).
toMont :: F -> Mont
toMont = MkMont . (* montMultiplier)
-- | Convert back from Montgomery form (@y ↦ y * R^(-1)@).
fromMont :: Mont -> F
fromMont m = case m of
  MkMont y -> y * invMontMultiplier
-- | Ring operations carried out directly on Montgomery-form values.
--
-- Addition, subtraction and negation are linear, so they act on the stored
-- values unchanged.  Multiplication must cancel one extra factor of R.
instance Num Mont where
  MkMont a + MkMont b = MkMont (a + b)
  MkMont a - MkMont b = MkMont (a - b)
  -- (aR)(bR)R^(-1) = (ab)R, so the product stays in Montgomery form.
  MkMont a * MkMont b = MkMont (a * b * invMontMultiplier)
  negate (MkMont a)   = MkMont (negate a)
  abs                 = id
  -- Montgomery form of 1 is 1 * R = R.
  signum _            = MkMont montMultiplier
  fromInteger n       = toMont (toF n)
--------------------------------------------------------------------------------
-- | Types that have a hexadecimal text rendering.
class ShowHex a where
  -- | Render the value as a hexadecimal string.
  showHex :: a -> String
-- | Print the 'showHex' rendering of a value, followed by a newline.
printHex :: ShowHex a => a -> IO ()
printHex = putStrLn . showHex
-- | @0x@-prefixed lowercase hexadecimal.  'printf' @%x@ rejects negative
-- 'Integer's with a runtime error, so negatives are rendered as @-0x...@
-- instead of crashing.
instance ShowHex Integer where
  showHex n
    | n < 0     = '-' : showHex (negate n)
    | otherwise = printf "0x%x" n
-- | Hex of the element's stored 'Integer'.
instance ShowHex F where
  showHex f = case f of
    MkF n -> showHex n

-- | Hex of the raw Montgomery-form value, i.e. the stored @x * R@, not the
-- decoded element.
instance ShowHex Mont where
  showHex m = case m of
    MkMont y -> showHex y
--------------------------------------------------------------------------------