initial attempt on a reference implementation of Skyscraper

This commit is contained in:
Balazs Komuves 2025-02-03 19:40:07 +01:00
parent 0db2f80df0
commit cf46381eed
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
3 changed files with 360 additions and 0 deletions

View File

@ -7,6 +7,8 @@ module BN254 where
--------------------------------------------------------------------------------
import Data.Bits
import Data.Ratio
import Text.Printf
--------------------------------------------------------------------------------
@ -48,4 +50,55 @@ power x0 exponent
s' = s*s
acc' = if e .&. 1 == 0 then acc else acc*s
invNaive :: F -> F
invNaive x = power x (fieldPrime - 2)
inv = invNaive
--------------------------------------------------------------------------------
instance Fractional F where
fromRational q = fromInteger (numerator q) / fromInteger (denominator q)
recip = inv
(/) x y = x * inv y
--------------------------------------------------------------------------------
newtype Mont
= MkMont F
deriving (Eq,Show)
montMultiplier :: F
montMultiplier = toF (2^256)
invMontMultiplier :: F
invMontMultiplier = inv montMultiplier
toMont :: F -> Mont
toMont x = MkMont (x * montMultiplier)
fromMont :: Mont -> F
fromMont (MkMont y) = (y * invMontMultiplier)
instance Num Mont where
fromInteger = toMont . toF . fromInteger
negate (MkMont x) = MkMont (negate x)
(+) (MkMont x) (MkMont y) = MkMont (x+y)
(-) (MkMont x) (MkMont y) = MkMont (x-y)
(*) (MkMont x) (MkMont y) = MkMont (x*y*invMontMultiplier)
abs x = x
signum _ = MkMont montMultiplier
--------------------------------------------------------------------------------
class ShowHex a where
showHex :: a -> String
printHex :: ShowHex a => a -> IO ()
printHex x = putStrLn (showHex x)
instance ShowHex Integer where showHex = printf "0x%x"
instance ShowHex F where showHex (MkF x) = showHex x
instance ShowHex Mont where showHex (MkMont y) = showHex y
--------------------------------------------------------------------------------

View File

@ -0,0 +1,253 @@
{-# LANGUAGE BangPatterns #-}
module Skyscraper.Permutation where
--------------------------------------------------------------------------------
import Data.Bits
import Data.Word
import BN254
import Skyscraper.RoundConst
--------------------------------------------------------------------------------
type F2 = Ext2 F
type F3 = Ext3 F
--------------------------------------------------------------------------------
-- | @F[X] / (X^2 + 5)@
data Ext2 a
= MkExt2 !a !a
deriving (Eq,Show)
negF2 :: F2 -> F2
negF2 (MkExt2 a0 a1) = MkExt2 (negate a0) (negate a1)
addF2 :: F2 -> F2 -> F2
addF2 (MkExt2 a0 a1) (MkExt2 b0 b1) = MkExt2 (a0+b0) (a1+b1)
mulF2 :: F2 -> F2 -> F2
mulF2 (MkExt2 a0 a1) (MkExt2 b0 b1) = MkExt2 c0 c1 where
c0 = a0*b0 - 5*a1*b1
c1 = a0*b1 + a1*b0
sclF2 :: F -> F2 -> F2
sclF2 s (MkExt2 a0 a1) = MkExt2 (s*a0) (s*a1)
instance Num F2 where
fromInteger x = MkExt2 (fromInteger x) 0
negate = negF2
(+) = addF2
(*) = mulF2
signum = error "Num/F2/signum"
abs = error "Num/F2/abs"
----------------------------------------
-- | @F[X] / (X^3 + 3)@
data Ext3 a
= MkExt3 !a !a !a
deriving (Eq,Show)
negF3 :: F3 -> F3
negF3 (MkExt3 a0 a1 a2) = MkExt3 (negate a0) (negate a1) (negate a2)
addF3 :: F3 -> F3 -> F3
addF3 (MkExt3 a0 a1 a2) (MkExt3 b0 b1 b2) = MkExt3 (a0+b0) (a1+b1) (a2+b2)
mulF3 :: F3 -> F3 -> F3
mulF3 (MkExt3 a0 a1 a2) (MkExt3 b0 b1 b2) = MkExt3 c0 c1 c2 where
c0 = a0*b0 - 3*a2*b1 - 3*a1*b2
c1 = a1*b0 + a0*b1 - 3*a2*b2
c2 = a2*b0 + a1*b1 + a0*b2
sclF3 :: F -> F3 -> F3
sclF3 s (MkExt3 a0 a1 a2) = MkExt3 (s*a0) (s*a1) (s*a2)
instance Num F3 where
fromInteger x = MkExt3 (fromInteger x) 0 0
negate = negF3
(+) = addF3
(*) = mulF3
signum = error "Num/F3/signum"
abs = error "Num/F3/abs"
--------------------------------------------------------------------------------
sboxByte :: Word8 -> Word8
sboxByte y = rol1 $ y `xor` (rol1 ny .&. rol2 y .&. rol3 y) where
ny = complement y
rol1 = flip rotateL 1
rol2 = flip rotateL 2
rol3 = flip rotateL 3
--------------------------------------------------------------------------------
integerToBytesLE :: Integer -> [Word8]
integerToBytesLE = go 32 where
go :: Int -> Integer -> [Word8]
go 0 0 = []
go 0 _ = error "integerToBytesLE: does not fit into 32 bytes"
go !k !x = fromInteger (x .&. 255) : go (k-1) (shiftR x 8)
integerToBytesBE :: Integer -> [Word8]
integerToBytesBE = reverse . integerToBytesLE
integerFromBytesLE :: [Word8] -> Integer
integerFromBytesLE = go where
go [] = 0
go (!x:xs) = fromIntegral x + shiftL (go xs) 8
integerFromBytesBE :: [Word8] -> Integer
integerFromBytesBE = integerFromBytesLE . reverse
fieldToBytes :: F -> [Word8]
fieldToBytes (MkF x) = integerToBytesBE x
fieldFromBytes :: [Word8] -> F
fieldFromBytes = toF . integerFromBytesBE
--------------------------------------------------------------------------------
rotateLeft :: Int -> [a] -> [a]
rotateLeft k xs = drop k xs ++ take k xs
partition :: Int -> [a] -> [[a]]
partition k = go where
go [] = []
go xs = take k xs : go (drop k xs)
--------------------------------------------------------------------------------
barsOnly1 :: F -> F
barsOnly1 input = output where
decomp = fieldToBytes input
rot = rotateLeft 16 decomp
sbox = map sboxByte rot
output = fieldFromBytes sbox
barsOnly2 :: Ext2 F -> Ext2 F
barsOnly2 (MkExt2 inp0 inp1) = MkExt2 out0 out1 where
decomp = fieldToBytes inp0 ++ fieldToBytes inp1
rot = rotateLeft 16 decomp
sbox = map sboxByte rot
[ys0,ys1] = partition 32 sbox
out0 = fieldFromBytes ys0
out1 = fieldFromBytes ys1
barsOnly3 :: Ext3 F -> Ext3 F
barsOnly3 (MkExt3 inp0 inp1 inp2) = MkExt3 out0 out1 out2 where
decomp = fieldToBytes inp0 ++ fieldToBytes inp1 ++ fieldToBytes inp2
rot = rotateLeft 16 decomp
sbox = map sboxByte rot
[ys0,ys1,ys2] = partition 32 sbox
out0 = fieldFromBytes ys0
out1 = fieldFromBytes ys1
out2 = fieldFromBytes ys2
--------------------------------------------------------------------------------
bars1 :: F -> (F,F) -> (F,F)
bars1 rc (l,r) = (l',r') where
l' = r + barsOnly1 l + rc
r' = l
bars2 :: F2 -> (F2,F2) -> (F2,F2)
bars2 rc (l,r) = (l',r') where
l' = r + barsOnly2 l + rc
r' = l
bars3 :: F3 -> (F3,F3) -> (F3,F3)
bars3 rc (l,r) = (l',r') where
l' = r + barsOnly3 l + rc
r' = l
--------------------------------------------------------------------------------
sq1 :: F -> (F,F) -> (F,F)
sq1 rc (l,r) = (l',r') where
l' = r + invMontMultiplier*l*l + rc
r' = l
sq2 :: F2 -> (F2,F2) -> (F2,F2)
sq2 rc (l,r) = (l',r') where
l' = r + sclF2 invMontMultiplier (l*l) + rc
r' = l
sq3 :: F3 -> (F3,F3) -> (F3,F3)
sq3 rc (l,r) = (l',r') where
l' = r + sclF3 invMontMultiplier (l*l) + rc
r' = l
--------------------------------------------------------------------------------
lkpKst :: Int -> F
lkpKst k = roundConstF !! k
rc1 :: Int -> F
rc1 k
| k == 0 || k == 9 = 0
| k < 0 || k > 9 = error "rc1: round counter out of range"
| otherwise = let i = k-1 in lkpKst i
rc2 :: Int -> F2
rc2 k
| k == 0 || k == 9 = 0
| k < 0 || k > 9 = error "rc2: round counter out of range"
| otherwise = let i = k-1 in MkExt2
(lkpKst (2*i ))
(lkpKst (2*i+1))
rc3 :: Int -> F3
rc3 k
| k == 0 || k == 9 = 0
| k < 0 || k > 9 = error "rc3: round counter out of range"
| otherwise = let i = k-1 in MkExt3
(lkpKst (3*i ))
(lkpKst (3*i+1))
(lkpKst (3*i+2))
--------------------------------------------------------------------------------
perm1 :: (F,F) -> (F,F)
perm1
= sq1 (rc1 9)
. sq1 (rc1 8)
. bars1 (rc1 7)
. bars1 (rc1 6)
. sq1 (rc1 5)
. sq1 (rc1 4)
. bars1 (rc1 3)
. bars1 (rc1 2)
. sq1 (rc1 1)
. sq1 (rc1 0)
perm2 :: (F2,F2) -> (F2,F2)
perm2
= sq2 (rc2 9)
. sq2 (rc2 8)
. bars2 (rc2 7)
. bars2 (rc2 6)
. sq2 (rc2 5)
. sq2 (rc2 4)
. bars2 (rc2 3)
. bars2 (rc2 2)
. sq2 (rc2 1)
. sq2 (rc2 0)
perm3 :: (F3,F3) -> (F3,F3)
perm3
= sq3 (rc3 9)
. sq3 (rc3 8)
. bars3 (rc3 7)
. bars3 (rc3 6)
. sq3 (rc3 5)
. sq3 (rc3 4)
. bars3 (rc3 3)
. bars3 (rc3 2)
. sq3 (rc3 1)
. sq3 (rc3 0)
--------------------------------------------------------------------------------

View File

@ -0,0 +1,54 @@
module Skyscraper.RoundConst where
--------------------------------------------------------------------------------
import BN254
--------------------------------------------------------------------------------
-- | Prelimilary round constants are generated as:
--
-- > NUMindex = SHA256( index (32bit) || "Skyscraper" (28 byte string) )
--
-- More precisely: the index 0..23 is 32 bit /big-endian/, and the ASCII string
-- @"Skyscraper"@ os padded by zero bytes, so that the total input length is
-- 32 bytes
--
-- The resulting 32 byte data is then interpreted as big-endian 256 bit numbers.
--
roundConstBigInt :: [Integer]
roundConstBigInt =
[ 17829420340877239108687448009732280677191990375576158938221412342251481978692
, 27740342931201890067831390843279536630457710544396725670188095857896839417202
, 17048088173265532689680903955395019356591870902241717143279822196003888806966
, 109512792282736997633398631034649037613028427788284511060520396554381700616124
, 23518768991468467328187394347260979305359711922005254253047385842741274989784
, 95360373645575887695357714105933674592754581048282220961740831584356266637451
, 57106046715138585370392400429108362862843547132381623658436718362793140581845
, 16971509144034029782226530622087626979814683266929655790026304723118124142299
, 8608910393531852188108777530736778805001620473682472554749734455948859886057
, 54566392379700209585884878067585451869449334062644668287971552334629853792764
, 18708129585851494907644197977764586873688181219062643217509404046560774277231
, 52159802752268413629255578890890486811485406260370834838036769779232029980820
, 98108525134123848500172941527936985409086875223276518679938863940387852105202
, 105831033594660236721345339515389948186304708551041643590872398526195732523291
, 53084450331558915295247017186532447841918727727492403087452333633170905880952
, 78730946611419899835403512890231154575719512053287438310527615801825503526967
, 62089842541186043938517187437087053794210809382724083686360536771123796704819
, 32303085017979849099049635709265581104054174293154472699090841350494692332148
, 19361794324495443451354916303398190341881571975219162871160427826227778850994
, 65021267664773559966759214868166670507376995901124257419858229816098767301789
, 94847021352352647235478120180321422709509900436733319635143815989658015262598
, 51591271359432809566841356156562526830388219805637947403945613063492005256674
, 44534956566050763472510245910556224585100739093572801527559057220740673520964
, 84085239597197409225577945757724209425761279846653606664394225962327262179862
]
--------------------------------------------------------------------------------
roundConstF :: [F]
roundConstF = map toF roundConstBigInt
--------------------------------------------------------------------------------