From cf46381eed466257a98111e7df9feafbacbd60f2 Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Mon, 3 Feb 2025 19:40:07 +0100 Subject: [PATCH] initial attempt on a reference implementation of Skyscraper --- reference/src/BN254.hs | 53 +++++ reference/src/Skyscraper/Permutation.hs | 253 ++++++++++++++++++++++++ reference/src/Skyscraper/RoundConst.hs | 54 +++++ 3 files changed, 360 insertions(+) create mode 100644 reference/src/Skyscraper/Permutation.hs create mode 100644 reference/src/Skyscraper/RoundConst.hs diff --git a/reference/src/BN254.hs b/reference/src/BN254.hs index 431aef3..d8ca071 100644 --- a/reference/src/BN254.hs +++ b/reference/src/BN254.hs @@ -7,6 +7,8 @@ module BN254 where -------------------------------------------------------------------------------- import Data.Bits +import Data.Ratio +import Text.Printf -------------------------------------------------------------------------------- @@ -48,4 +50,55 @@ power x0 exponent s' = s*s acc' = if e .&. 1 == 0 then acc else acc*s +invNaive :: F -> F +invNaive x = power x (fieldPrime - 2) + +inv = invNaive + +-------------------------------------------------------------------------------- + +instance Fractional F where + fromRational q = fromInteger (numerator q) / fromInteger (denominator q) + recip = inv + (/) x y = x * inv y + +-------------------------------------------------------------------------------- + +newtype Mont + = MkMont F + deriving (Eq,Show) + +montMultiplier :: F +montMultiplier = toF (2^256) + +invMontMultiplier :: F +invMontMultiplier = inv montMultiplier + +toMont :: F -> Mont +toMont x = MkMont (x * montMultiplier) + +fromMont :: Mont -> F +fromMont (MkMont y) = (y * invMontMultiplier) + +instance Num Mont where + fromInteger = toMont . toF . fromInteger + negate (MkMont x) = MkMont (negate x) + (+) (MkMont x) (MkMont y) = MkMont (x+y) + (-) (MkMont x) (MkMont y) = MkMont (x-y) + (*) (MkMont x) (MkMont y) = MkMont (x*y*invMontMultiplier) + abs x = x + signum _ = MkMont montMultiplier + +-------------------------------------------------------------------------------- + +class ShowHex a where + showHex :: a -> String + +printHex :: ShowHex a => a -> IO () +printHex x = putStrLn (showHex x) + +instance ShowHex Integer where showHex = printf "0x%x" +instance ShowHex F where showHex (MkF x) = showHex x +instance ShowHex Mont where showHex (MkMont y) = showHex y + -------------------------------------------------------------------------------- diff --git a/reference/src/Skyscraper/Permutation.hs b/reference/src/Skyscraper/Permutation.hs new file mode 100644 index 0000000..5a95a97 --- /dev/null +++ b/reference/src/Skyscraper/Permutation.hs @@ -0,0 +1,253 @@ + +{-# LANGUAGE BangPatterns #-} +module Skyscraper.Permutation where + +-------------------------------------------------------------------------------- + +import Data.Bits +import Data.Word + +import BN254 +import Skyscraper.RoundConst + +-------------------------------------------------------------------------------- + +type F2 = Ext2 F +type F3 = Ext3 F + +-------------------------------------------------------------------------------- + +-- | @F[X] / (X^2 + 5)@ +data Ext2 a + = MkExt2 !a !a + deriving (Eq,Show) + +negF2 :: F2 -> F2 +negF2 (MkExt2 a0 a1) = MkExt2 (negate a0) (negate a1) + +addF2 :: F2 -> F2 -> F2 +addF2 (MkExt2 a0 a1) (MkExt2 b0 b1) = MkExt2 (a0+b0) (a1+b1) + +mulF2 :: F2 -> F2 -> F2 +mulF2 (MkExt2 a0 a1) (MkExt2 b0 b1) = MkExt2 c0 c1 where + c0 = a0*b0 - 5*a1*b1 + c1 = a0*b1 + a1*b0 + +sclF2 :: F -> F2 -> F2 +sclF2 s (MkExt2 a0 a1) = MkExt2 (s*a0) (s*a1) + +instance Num F2 where + fromInteger x = MkExt2 (fromInteger x) 0 + negate = negF2 + (+) = addF2 + (*) = mulF2 + signum = error "Num/F2/signum" + abs = error "Num/F2/abs" + +---------------------------------------- + +-- | @F[X] / (X^3 + 3)@ +data Ext3 a + = MkExt3 !a !a !a + deriving (Eq,Show) + +negF3 :: F3 -> F3 +negF3 (MkExt3 a0 a1 a2) = MkExt3 (negate a0) (negate a1) (negate a2) + +addF3 :: F3 -> F3 -> F3 +addF3 (MkExt3 a0 a1 a2) (MkExt3 b0 b1 b2) = MkExt3 (a0+b0) (a1+b1) (a2+b2) + +mulF3 :: F3 -> F3 -> F3 +mulF3 (MkExt3 a0 a1 a2) (MkExt3 b0 b1 b2) = MkExt3 c0 c1 c2 where + c0 = a0*b0 - 3*a2*b1 - 3*a1*b2 + c1 = a1*b0 + a0*b1 - 3*a2*b2 + c2 = a2*b0 + a1*b1 + a0*b2 + +sclF3 :: F -> F3 -> F3 +sclF3 s (MkExt3 a0 a1 a2) = MkExt3 (s*a0) (s*a1) (s*a2) + +instance Num F3 where + fromInteger x = MkExt3 (fromInteger x) 0 0 + negate = negF3 + (+) = addF3 + (*) = mulF3 + signum = error "Num/F3/signum" + abs = error "Num/F3/abs" + +-------------------------------------------------------------------------------- + +sboxByte :: Word8 -> Word8 +sboxByte y = rol1 $ y `xor` (rol1 ny .&. rol2 y .&. rol3 y) where + ny = complement y + rol1 = flip rotateL 1 + rol2 = flip rotateL 2 + rol3 = flip rotateL 3 + +-------------------------------------------------------------------------------- + +integerToBytesLE :: Integer -> [Word8] +integerToBytesLE = go 32 where + go :: Int -> Integer -> [Word8] + go 0 0 = [] + go 0 _ = error "integerToBytesLE: does not fit into 32 bytes" + go !k !x = fromInteger (x .&. 255) : go (k-1) (shiftR x 8) + +integerToBytesBE :: Integer -> [Word8] +integerToBytesBE = reverse . integerToBytesLE + +integerFromBytesLE :: [Word8] -> Integer +integerFromBytesLE = go where + go [] = 0 + go (!x:xs) = fromIntegral x + shiftL (go xs) 8 + +integerFromBytesBE :: [Word8] -> Integer +integerFromBytesBE = integerFromBytesLE . reverse + +fieldToBytes :: F -> [Word8] +fieldToBytes (MkF x) = integerToBytesBE x + +fieldFromBytes :: [Word8] -> F +fieldFromBytes = toF . integerFromBytesBE + +-------------------------------------------------------------------------------- + +rotateLeft :: Int -> [a] -> [a] +rotateLeft k xs = drop k xs ++ take k xs + +partition :: Int -> [a] -> [[a]] +partition k = go where + go [] = [] + go xs = take k xs : go (drop k xs) + +-------------------------------------------------------------------------------- + +barsOnly1 :: F -> F +barsOnly1 input = output where + decomp = fieldToBytes input + rot = rotateLeft 16 decomp + sbox = map sboxByte rot + output = fieldFromBytes sbox + +barsOnly2 :: Ext2 F -> Ext2 F +barsOnly2 (MkExt2 inp0 inp1) = MkExt2 out0 out1 where + decomp = fieldToBytes inp0 ++ fieldToBytes inp1 + rot = rotateLeft 16 decomp + sbox = map sboxByte rot + [ys0,ys1] = partition 32 sbox + out0 = fieldFromBytes ys0 + out1 = fieldFromBytes ys1 + +barsOnly3 :: Ext3 F -> Ext3 F +barsOnly3 (MkExt3 inp0 inp1 inp2) = MkExt3 out0 out1 out2 where + decomp = fieldToBytes inp0 ++ fieldToBytes inp1 ++ fieldToBytes inp2 + rot = rotateLeft 16 decomp + sbox = map sboxByte rot + [ys0,ys1,ys2] = partition 32 sbox + out0 = fieldFromBytes ys0 + out1 = fieldFromBytes ys1 + out2 = fieldFromBytes ys2 + +-------------------------------------------------------------------------------- + +bars1 :: F -> (F,F) -> (F,F) +bars1 rc (l,r) = (l',r') where + l' = r + barsOnly1 l + rc + r' = l + +bars2 :: F2 -> (F2,F2) -> (F2,F2) +bars2 rc (l,r) = (l',r') where + l' = r + barsOnly2 l + rc + r' = l + +bars3 :: F3 -> (F3,F3) -> (F3,F3) +bars3 rc (l,r) = (l',r') where + l' = r + barsOnly3 l + rc + r' = l + +-------------------------------------------------------------------------------- + +sq1 :: F -> (F,F) -> (F,F) +sq1 rc (l,r) = (l',r') where + l' = r + invMontMultiplier*l*l + rc + r' = l + +sq2 :: F2 -> (F2,F2) -> (F2,F2) +sq2 rc (l,r) = (l',r') where + l' = r + sclF2 invMontMultiplier (l*l) + rc + r' = l + +sq3 :: F3 -> (F3,F3) -> (F3,F3) +sq3 rc (l,r) = (l',r') where + l' = r + sclF3 invMontMultiplier (l*l) + rc + r' = l + +-------------------------------------------------------------------------------- + +lkpKst :: Int -> F +lkpKst k = roundConstF !! k + +rc1 :: Int -> F +rc1 k + | k == 0 || k == 9 = 0 + | k < 0 || k > 9 = error "rc1: round counter out of range" + | otherwise = let i = k-1 in lkpKst i + +rc2 :: Int -> F2 +rc2 k + | k == 0 || k == 9 = 0 + | k < 0 || k > 9 = error "rc2: round counter out of range" + | otherwise = let i = k-1 in MkExt2 + (lkpKst (2*i )) + (lkpKst (2*i+1)) + +rc3 :: Int -> F3 +rc3 k + | k == 0 || k == 9 = 0 + | k < 0 || k > 9 = error "rc3: round counter out of range" + | otherwise = let i = k-1 in MkExt3 + (lkpKst (3*i )) + (lkpKst (3*i+1)) + (lkpKst (3*i+2)) + +-------------------------------------------------------------------------------- + +perm1 :: (F,F) -> (F,F) +perm1 + = sq1 (rc1 9) + . sq1 (rc1 8) + . bars1 (rc1 7) + . bars1 (rc1 6) + . sq1 (rc1 5) + . sq1 (rc1 4) + . bars1 (rc1 3) + . bars1 (rc1 2) + . sq1 (rc1 1) + . sq1 (rc1 0) + +perm2 :: (F2,F2) -> (F2,F2) +perm2 + = sq2 (rc2 9) + . sq2 (rc2 8) + . bars2 (rc2 7) + . bars2 (rc2 6) + . sq2 (rc2 5) + . sq2 (rc2 4) + . bars2 (rc2 3) + . bars2 (rc2 2) + . sq2 (rc2 1) + . sq2 (rc2 0) + +perm3 :: (F3,F3) -> (F3,F3) +perm3 + = sq3 (rc3 9) + . sq3 (rc3 8) + . bars3 (rc3 7) + . bars3 (rc3 6) + . sq3 (rc3 5) + . sq3 (rc3 4) + . bars3 (rc3 3) + . bars3 (rc3 2) + . sq3 (rc3 1) + . sq3 (rc3 0) + +-------------------------------------------------------------------------------- diff --git a/reference/src/Skyscraper/RoundConst.hs b/reference/src/Skyscraper/RoundConst.hs new file mode 100644 index 0000000..78340e7 --- /dev/null +++ b/reference/src/Skyscraper/RoundConst.hs @@ -0,0 +1,54 @@ + +module Skyscraper.RoundConst where + +-------------------------------------------------------------------------------- + +import BN254 + +-------------------------------------------------------------------------------- + +-- | Prelimilary round constants are generated as: +-- +-- > NUMindex = SHA256( index (32bit) || "Skyscraper" (28 byte string) ) +-- +-- More precisely: the index 0..23 is 32 bit /big-endian/, and the ASCII string +-- @"Skyscraper"@ os padded by zero bytes, so that the total input length is +-- 32 bytes +-- +-- The resulting 32 byte data is then interpreted as big-endian 256 bit numbers. +-- +roundConstBigInt :: [Integer] +roundConstBigInt = + [ 17829420340877239108687448009732280677191990375576158938221412342251481978692 + , 27740342931201890067831390843279536630457710544396725670188095857896839417202 + , 17048088173265532689680903955395019356591870902241717143279822196003888806966 + , 109512792282736997633398631034649037613028427788284511060520396554381700616124 + , 23518768991468467328187394347260979305359711922005254253047385842741274989784 + , 95360373645575887695357714105933674592754581048282220961740831584356266637451 + , 57106046715138585370392400429108362862843547132381623658436718362793140581845 + , 16971509144034029782226530622087626979814683266929655790026304723118124142299 + , 8608910393531852188108777530736778805001620473682472554749734455948859886057 + , 54566392379700209585884878067585451869449334062644668287971552334629853792764 + , 18708129585851494907644197977764586873688181219062643217509404046560774277231 + , 52159802752268413629255578890890486811485406260370834838036769779232029980820 + , 98108525134123848500172941527936985409086875223276518679938863940387852105202 + , 105831033594660236721345339515389948186304708551041643590872398526195732523291 + , 53084450331558915295247017186532447841918727727492403087452333633170905880952 + , 78730946611419899835403512890231154575719512053287438310527615801825503526967 + , 62089842541186043938517187437087053794210809382724083686360536771123796704819 + , 32303085017979849099049635709265581104054174293154472699090841350494692332148 + , 19361794324495443451354916303398190341881571975219162871160427826227778850994 + , 65021267664773559966759214868166670507376995901124257419858229816098767301789 + , 94847021352352647235478120180321422709509900436733319635143815989658015262598 + , 51591271359432809566841356156562526830388219805637947403945613063492005256674 + , 44534956566050763472510245910556224585100739093572801527559057220740673520964 + , 84085239597197409225577945757724209425761279846653606664394225962327262179862 + ] + +-------------------------------------------------------------------------------- + +roundConstF :: [F] +roundConstF = map toF roundConstBigInt + +-------------------------------------------------------------------------------- +