refactored the Haskell reference Skyscraper permutation code and added sponge

This commit is contained in:
Balazs Komuves 2025-02-12 21:26:11 +01:00
parent b45308432b
commit b65d13f888
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
3 changed files with 232 additions and 133 deletions

View File

@ -0,0 +1,105 @@
-- | Degree 2 and 3 field extension of the BN254 scalar field
{-# LANGUAGE BangPatterns, DataKinds, TypeFamilies #-}
module Skyscraper.FieldExt where
--------------------------------------------------------------------------------
import Data.Bits
import Data.Word
import Data.Proxy
import BN254
--------------------------------------------------------------------------------
type F2 = Ext2 F
type F3 = Ext3 F
--------------------------------------------------------------------------------
class Num ext => FieldExtension ext where
dimension :: Proxy ext -> Int
scale :: F -> ext -> ext
extToList :: ext -> [F]
extFromList :: [F] -> ext
instance FieldExtension F where
dimension = const 1
scale = (*)
extToList x = [x]
extFromList [x] = x
--------------------------------------------------------------------------------
-- | @F[X] / (X^2 + 5)@
data Ext2 a
= MkExt2 !a !a
deriving (Eq,Show)
negF2 :: F2 -> F2
negF2 (MkExt2 a0 a1) = MkExt2 (negate a0) (negate a1)
addF2 :: F2 -> F2 -> F2
addF2 (MkExt2 a0 a1) (MkExt2 b0 b1) = MkExt2 (a0+b0) (a1+b1)
mulF2 :: F2 -> F2 -> F2
mulF2 (MkExt2 a0 a1) (MkExt2 b0 b1) = MkExt2 c0 c1 where
c0 = a0*b0 - 5*a1*b1
c1 = a0*b1 + a1*b0
sclF2 :: F -> F2 -> F2
sclF2 s (MkExt2 a0 a1) = MkExt2 (s*a0) (s*a1)
instance Num F2 where
fromInteger x = MkExt2 (fromInteger x) 0
negate = negF2
(+) = addF2
(*) = mulF2
signum = error "Num/F2/signum"
abs = error "Num/F2/abs"
instance FieldExtension F2 where
dimension = const 2
scale = sclF2
extToList (MkExt2 x y) = [x,y]
extFromList [x,y] = MkExt2 x y
----------------------------------------
-- | @F[X] / (X^3 + 3)@
data Ext3 a
= MkExt3 !a !a !a
deriving (Eq,Show)
negF3 :: F3 -> F3
negF3 (MkExt3 a0 a1 a2) = MkExt3 (negate a0) (negate a1) (negate a2)
addF3 :: F3 -> F3 -> F3
addF3 (MkExt3 a0 a1 a2) (MkExt3 b0 b1 b2) = MkExt3 (a0+b0) (a1+b1) (a2+b2)
mulF3 :: F3 -> F3 -> F3
mulF3 (MkExt3 a0 a1 a2) (MkExt3 b0 b1 b2) = MkExt3 c0 c1 c2 where
c0 = a0*b0 - 3*a2*b1 - 3*a1*b2
c1 = a1*b0 + a0*b1 - 3*a2*b2
c2 = a2*b0 + a1*b1 + a0*b2
sclF3 :: F -> F3 -> F3
sclF3 s (MkExt3 a0 a1 a2) = MkExt3 (s*a0) (s*a1) (s*a2)
instance Num F3 where
fromInteger x = MkExt3 (fromInteger x) 0 0
negate = negF3
(+) = addF3
(*) = mulF3
signum = error "Num/F3/signum"
abs = error "Num/F3/abs"
instance FieldExtension F3 where
dimension = const 3
scale = sclF3
extToList (MkExt3 x y z) = [x,y,z]
extFromList [x,y,z] = MkExt3 x y z
--------------------------------------------------------------------------------

View File

@ -4,26 +4,30 @@
-- <https://extgit.isec.tugraz.at/krypto/zkfriendlyhashzoo>
--
{-# LANGUAGE TypeApplications, DataKinds #-}
module Skyscraper.KATs where
--------------------------------------------------------------------------------
import Data.Proxy
import BN254
import Skyscraper.FieldExt
import Skyscraper.Permutation
--------------------------------------------------------------------------------
testSkyscraperKATs :: [Bool]
testSkyscraperKATs = [ok1,ok2,ok3] where
ok1 = (perm1 input1 == output1)
ok2 = (perm2 input2 == output2)
ok3 = (perm3 input3 == output3)
ok1 = (permute (Proxy @Sky2) input1 == output1)
ok2 = (permute (Proxy @Sky4) input2 == output2)
ok3 = (permute (Proxy @Sky6) input3 == output3)
--------------------------------------------------------------------------------
-- Note: the official implementation stores the field extension vector in the "wrong" order
mkExt2 (b,a) = MkExt2 a b
mkExt3 (c,b,a) = MkExt3 a b c
revMkExt2 (b,a) = MkExt2 a b
revMkExt3 (c,b,a) = MkExt3 a b c
input1 :: (F,F)
input1 =
@ -39,26 +43,26 @@ output1 =
input2 :: (F2,F2)
input2 =
( mkExt2 (0x0004d2 , 0x00029a)
, mkExt2 (0x000162e, 0x000309)
( revMkExt2 (0x0004d2 , 0x00029a)
, revMkExt2 (0x000162e, 0x000309)
)
output2 :: (F2,F2)
output2 =
( mkExt2 ( 0x2456fa7300e7899364d1b7b933ee989a11606ad64b3166bb6e2822d46b2979db, 0x151084b3967b629e9103c5b85cf76bf47557d71a492e9575eb0dc2ac7bac0af3 )
, mkExt2 ( 0x1f7964c2b3b354824906659089ac272aaa84cb50214c200bcf674d677cc83ee2, 0x29c9023fff7db7812cde8269233632415bb56ea4c8644c1edcfdbde68cc29dfe )
( revMkExt2 ( 0x2456fa7300e7899364d1b7b933ee989a11606ad64b3166bb6e2822d46b2979db, 0x151084b3967b629e9103c5b85cf76bf47557d71a492e9575eb0dc2ac7bac0af3 )
, revMkExt2 ( 0x1f7964c2b3b354824906659089ac272aaa84cb50214c200bcf674d677cc83ee2, 0x29c9023fff7db7812cde8269233632415bb56ea4c8644c1edcfdbde68cc29dfe )
)
input3 :: (F3,F3)
input3 =
( mkExt3 (0x0004d2 , 0x00029a, 0x0003e9 )
, mkExt3 (0x000162e, 0x000309, 0x0007d2 )
( revMkExt3 (0x0004d2 , 0x00029a, 0x0003e9 )
, revMkExt3 (0x000162e, 0x000309, 0x0007d2 )
)
output3 :: (F3,F3)
output3 =
( mkExt3 ( 0x2ca9324e4d13668786f9f2dadf2ac3baf75f4bd57e14150c3421061d377edb6f, 0x2b87210ee7202515405c813b366da0b944c393e13332fd746ac19629c3b86486, 0x1bf5b2a7bed61ddd44f1d5a01492f203bdfd4973d68f3d91dddfdb8bc5b2db70 )
, mkExt3 ( 0x9896a6aa7e71659af7c11a42e6c95a361225befaf1613e7253c5224165ecaf5 , 0x1f40923d58dbe5ee7b9bd58ff493cd5141b0fd75da57434c41aeb3ddfcbe3c37, 0x26afdb536cae8d809d0609a51c463d80198d2530e32053c5e888301953dfe670 )
( revMkExt3 ( 0x2ca9324e4d13668786f9f2dadf2ac3baf75f4bd57e14150c3421061d377edb6f, 0x2b87210ee7202515405c813b366da0b944c393e13332fd746ac19629c3b86486, 0x1bf5b2a7bed61ddd44f1d5a01492f203bdfd4973d68f3d91dddfdb8bc5b2db70 )
, revMkExt3 ( 0x9896a6aa7e71659af7c11a42e6c95a361225befaf1613e7253c5224165ecaf5 , 0x1f40923d58dbe5ee7b9bd58ff493cd5141b0fd75da57434c41aeb3ddfcbe3c37, 0x26afdb536cae8d809d0609a51c463d80198d2530e32053c5e888301953dfe670 )
)
--------------------------------------------------------------------------------

View File

@ -1,5 +1,5 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BangPatterns, DataKinds, TypeFamilies, ScopedTypeVariables, TypeApplications #-}
module Skyscraper.Permutation where
--------------------------------------------------------------------------------
@ -7,72 +7,135 @@ module Skyscraper.Permutation where
import Data.Bits
import Data.Word
import Data.Kind
import Data.Proxy
import BN254
import Skyscraper.RoundConst
import Skyscraper.FieldExt
--------------------------------------------------------------------------------
type F2 = Ext2 F
type F3 = Ext3 F
data SkyWidth
= Sky2 -- ^ 2 field elements
| Sky4 -- ^ degree 2 field extension (= 4 field elements)
| Sky6 -- ^ degree 3 field extension (= 6 field elements)
deriving (Eq,Show)
--------------------------------------------------------------------------------
-- | @F[X] / (X^2 + 5)@
data Ext2 a
= MkExt2 !a !a
deriving (Eq,Show)
type Digest = F
negF2 :: F2 -> F2
negF2 (MkExt2 a0 a1) = MkExt2 (negate a0) (negate a1)
type State sky = (FieldExt sky, FieldExt sky)
addF2 :: F2 -> F2 -> F2
addF2 (MkExt2 a0 a1) (MkExt2 b0 b1) = MkExt2 (a0+b0) (a1+b1)
stateFromList :: forall (sky :: SkyWidth). SkyscraperImpl sky => Proxy sky -> [F] -> State sky
stateFromList pxy xs = case splitAt (dimension $ Proxy @(FieldExt sky)) xs of
(us,vs) -> (extFromList us, extFromList vs)
mulF2 :: F2 -> F2 -> F2
mulF2 (MkExt2 a0 a1) (MkExt2 b0 b1) = MkExt2 c0 c1 where
c0 = a0*b0 - 5*a1*b1
c1 = a0*b1 + a1*b0
stateToList :: SkyscraperImpl sky => Proxy sky -> State sky -> [F]
stateToList pxy (x,y) = extToList x ++ extToList y
sclF2 :: F -> F2 -> F2
sclF2 s (MkExt2 a0 a1) = MkExt2 (s*a0) (s*a1)
extractDigest :: SkyscraperImpl sky => Proxy sky -> State sky -> F
extractDigest pxy = head . stateToList pxy
instance Num F2 where
fromInteger x = MkExt2 (fromInteger x) 0
negate = negF2
(+) = addF2
(*) = mulF2
signum = error "Num/F2/signum"
abs = error "Num/F2/abs"
initialState :: SkyscraperImpl sky => Proxy sky -> State sky
initialState pxy = stateFromList pxy $ replicate (rate pxy) 0 ++ [capacityIV pxy]
----------------------------------------
--------------------------------------------------------------------------------
-- | @F[X] / (X^3 + 3)@
data Ext3 a
= MkExt3 !a !a !a
deriving (Eq,Show)
class FieldExtension (FieldExt sky) => SkyscraperImpl (sky :: SkyWidth) where
negF3 :: F3 -> F3
negF3 (MkExt3 a0 a1 a2) = MkExt3 (negate a0) (negate a1) (negate a2)
type FieldExt sky :: Type
addF3 :: F3 -> F3 -> F3
addF3 (MkExt3 a0 a1 a2) (MkExt3 b0 b1 b2) = MkExt3 (a0+b0) (a1+b1) (a2+b2)
stateWidth :: Proxy sky -> Int
rate :: Proxy sky -> Int
capacity :: Proxy sky -> Int
capacityIV :: Proxy sky -> F
mulF3 :: F3 -> F3 -> F3
mulF3 (MkExt3 a0 a1 a2) (MkExt3 b0 b1 b2) = MkExt3 c0 c1 c2 where
c0 = a0*b0 - 3*a2*b1 - 3*a1*b2
c1 = a1*b0 + a0*b1 - 3*a2*b2
c2 = a2*b0 + a1*b1 + a0*b2
barsOnly :: Proxy sky -> FieldExt sky -> FieldExt sky
roundConst :: Proxy sky -> Int -> FieldExt sky
sclF3 :: F -> F3 -> F3
sclF3 s (MkExt3 a0 a1 a2) = MkExt3 (s*a0) (s*a1) (s*a2)
capacity pxy = 1
rate pxy = stateWidth pxy - capacity pxy
stateWidth pxy = capacity pxy + rate pxy
instance Num F3 where
fromInteger x = MkExt3 (fromInteger x) 0 0
negate = negF3
(+) = addF3
(*) = mulF3
signum = error "Num/F3/signum"
abs = error "Num/F3/abs"
--------------------------------------------------------------------------------
bars :: SkyscraperImpl sky => Proxy sky -> FieldExt sky -> State sky -> State sky
bars pxy rc (l,r) = (l',r') where
l' = r + barsOnly pxy l + rc
r' = l
sq :: SkyscraperImpl sky => Proxy sky -> FieldExt sky -> State sky -> State sky
sq pxy rc (l,r) = (l',r') where
l' = r + scale invMontMultiplier (l*l) + rc
r' = l
permute :: SkyscraperImpl sky => Proxy sky -> State sky -> State sky
permute pxy
= sq pxy (roundConst pxy 9)
. sq pxy (roundConst pxy 8)
. bars pxy (roundConst pxy 7)
. bars pxy (roundConst pxy 6)
. sq pxy (roundConst pxy 5)
. sq pxy (roundConst pxy 4)
. bars pxy (roundConst pxy 3)
. bars pxy (roundConst pxy 2)
. sq pxy (roundConst pxy 1)
. sq pxy (roundConst pxy 0)
--------------------------------------------------------------------------------
compress :: SkyscraperImpl sky => Proxy sky -> Digest -> Digest -> Digest
compress pxy x y = extractDigest pxy $ permute pxy $ input where
input = stateFromList pxy $ [x,y] ++ replicate (stateWidth pxy - 2) 0
spongeWithPad :: SkyscraperImpl sky => Proxy sky -> [F] -> Digest
spongeWithPad pxy xs = spongeNoPad pxy (xs ++ [1])
spongeNoPad :: forall (sky :: SkyWidth). SkyscraperImpl sky => Proxy sky -> [F] -> Digest
spongeNoPad _ [] = error "spongeNoPad: empty input"
spongeNoPad pxy list = extractDigest pxy $ go (initialState pxy) list where
go :: State sky -> [F] -> State sky
go state [] = state
go state what = case splitAt (rate pxy) what of
(this,rest) -> go (permute pxy (mixDataWithState pxy this state)) rest
mixDataWithState :: SkyscraperImpl sky => Proxy sky -> [F] -> State sky -> State sky
mixDataWithState pxy input state
| m == 0 = error "mixDataWithState: input is empty"
| m > r = error "mixDataWithState: input longer than the rate"
| otherwise = stateFromList pxy $ mix $ stateToList pxy state
where
r = rate pxy
m = length input
mix old = zipWith (+) old (input ++ repeat 0)
--------------------------------------------------------------------------------
instance SkyscraperImpl Sky2 where
type FieldExt Sky2 = F
stateWidth _ = 2
capacity _ = 1
capacityIV _ = 1
barsOnly _ = barsOnly1
roundConst _ = rc1
instance SkyscraperImpl Sky4 where
type FieldExt Sky4 = F2
stateWidth _ = 4
capacity _ = 1
capacityIV _ = 2
barsOnly _ = barsOnly2
roundConst _ = rc2
instance SkyscraperImpl Sky6 where
type FieldExt Sky6 = F3
stateWidth _ = 6
capacity _ = 1
capacityIV _ = 3
barsOnly _ = barsOnly3
roundConst _ = rc3
--------------------------------------------------------------------------------
@ -149,40 +212,6 @@ barsOnly3 (MkExt3 inp0 inp1 inp2) = MkExt3 out0 out1 out2 where
--------------------------------------------------------------------------------
bars1 :: F -> (F,F) -> (F,F)
bars1 rc (l,r) = (l',r') where
l' = r + barsOnly1 l + rc
r' = l
bars2 :: F2 -> (F2,F2) -> (F2,F2)
bars2 rc (l,r) = (l',r') where
l' = r + barsOnly2 l + rc
r' = l
bars3 :: F3 -> (F3,F3) -> (F3,F3)
bars3 rc (l,r) = (l',r') where
l' = r + barsOnly3 l + rc
r' = l
--------------------------------------------------------------------------------
sq1 :: F -> (F,F) -> (F,F)
sq1 rc (l,r) = (l',r') where
l' = r + invMontMultiplier*l*l + rc
r' = l
sq2 :: F2 -> (F2,F2) -> (F2,F2)
sq2 rc (l,r) = (l',r') where
l' = r + sclF2 invMontMultiplier (l*l) + rc
r' = l
sq3 :: F3 -> (F3,F3) -> (F3,F3)
sq3 rc (l,r) = (l',r') where
l' = r + sclF3 invMontMultiplier (l*l) + rc
r' = l
--------------------------------------------------------------------------------
lkpKst :: Int -> F
lkpKst k = roundConstF !! k
@ -211,43 +240,4 @@ rc3 k
--------------------------------------------------------------------------------
perm1 :: (F,F) -> (F,F)
perm1
= sq1 (rc1 9)
. sq1 (rc1 8)
. bars1 (rc1 7)
. bars1 (rc1 6)
. sq1 (rc1 5)
. sq1 (rc1 4)
. bars1 (rc1 3)
. bars1 (rc1 2)
. sq1 (rc1 1)
. sq1 (rc1 0)
perm2 :: (F2,F2) -> (F2,F2)
perm2
= sq2 (rc2 9)
. sq2 (rc2 8)
. bars2 (rc2 7)
. bars2 (rc2 6)
. sq2 (rc2 5)
. sq2 (rc2 4)
. bars2 (rc2 3)
. bars2 (rc2 2)
. sq2 (rc2 1)
. sq2 (rc2 0)
perm3 :: (F3,F3) -> (F3,F3)
perm3
= sq3 (rc3 9)
. sq3 (rc3 8)
. bars3 (rc3 7)
. bars3 (rc3 6)
. sq3 (rc3 5)
. sq3 (rc3 4)
. bars3 (rc3 3)
. bars3 (rc3 2)
. sq3 (rc3 1)
. sq3 (rc3 0)
--------------------------------------------------------------------------------