diff --git a/reference/src/Skyscraper/FieldExt.hs b/reference/src/Skyscraper/FieldExt.hs new file mode 100644 index 0000000..555f2ca --- /dev/null +++ b/reference/src/Skyscraper/FieldExt.hs @@ -0,0 +1,105 @@ + +-- | Degree 2 and 3 field extension of the BN254 scalar field + +{-# LANGUAGE BangPatterns, DataKinds, TypeFamilies #-} +module Skyscraper.FieldExt where + +-------------------------------------------------------------------------------- + +import Data.Bits +import Data.Word +import Data.Proxy + +import BN254 + +-------------------------------------------------------------------------------- + +type F2 = Ext2 F +type F3 = Ext3 F + +-------------------------------------------------------------------------------- + +class Num ext => FieldExtension ext where + dimension :: Proxy ext -> Int + scale :: F -> ext -> ext + extToList :: ext -> [F] + extFromList :: [F] -> ext + +instance FieldExtension F where + dimension = const 1 + scale = (*) + extToList x = [x] + extFromList [x] = x + +-------------------------------------------------------------------------------- + +-- | @F[X] / (X^2 + 5)@ +data Ext2 a + = MkExt2 !a !a + deriving (Eq,Show) + +negF2 :: F2 -> F2 +negF2 (MkExt2 a0 a1) = MkExt2 (negate a0) (negate a1) + +addF2 :: F2 -> F2 -> F2 +addF2 (MkExt2 a0 a1) (MkExt2 b0 b1) = MkExt2 (a0+b0) (a1+b1) + +mulF2 :: F2 -> F2 -> F2 +mulF2 (MkExt2 a0 a1) (MkExt2 b0 b1) = MkExt2 c0 c1 where + c0 = a0*b0 - 5*a1*b1 + c1 = a0*b1 + a1*b0 + +sclF2 :: F -> F2 -> F2 +sclF2 s (MkExt2 a0 a1) = MkExt2 (s*a0) (s*a1) + +instance Num F2 where + fromInteger x = MkExt2 (fromInteger x) 0 + negate = negF2 + (+) = addF2 + (*) = mulF2 + signum = error "Num/F2/signum" + abs = error "Num/F2/abs" + +instance FieldExtension F2 where + dimension = const 2 + scale = sclF2 + extToList (MkExt2 x y) = [x,y] + extFromList [x,y] = MkExt2 x y + +---------------------------------------- + +-- | @F[X] / (X^3 + 3)@ +data Ext3 a + = MkExt3 !a !a !a + deriving (Eq,Show) + +negF3 :: F3 -> F3 +negF3 (MkExt3 a0 a1 a2) = MkExt3 (negate a0) (negate a1) (negate a2) + +addF3 :: F3 -> F3 -> F3 +addF3 (MkExt3 a0 a1 a2) (MkExt3 b0 b1 b2) = MkExt3 (a0+b0) (a1+b1) (a2+b2) + +mulF3 :: F3 -> F3 -> F3 +mulF3 (MkExt3 a0 a1 a2) (MkExt3 b0 b1 b2) = MkExt3 c0 c1 c2 where + c0 = a0*b0 - 3*a2*b1 - 3*a1*b2 + c1 = a1*b0 + a0*b1 - 3*a2*b2 + c2 = a2*b0 + a1*b1 + a0*b2 + +sclF3 :: F -> F3 -> F3 +sclF3 s (MkExt3 a0 a1 a2) = MkExt3 (s*a0) (s*a1) (s*a2) + +instance Num F3 where + fromInteger x = MkExt3 (fromInteger x) 0 0 + negate = negF3 + (+) = addF3 + (*) = mulF3 + signum = error "Num/F3/signum" + abs = error "Num/F3/abs" + +instance FieldExtension F3 where + dimension = const 3 + scale = sclF3 + extToList (MkExt3 x y z) = [x,y,z] + extFromList [x,y,z] = MkExt3 x y z + +-------------------------------------------------------------------------------- diff --git a/reference/src/Skyscraper/KATs.hs b/reference/src/Skyscraper/KATs.hs index f0c5d55..5472dd8 100644 --- a/reference/src/Skyscraper/KATs.hs +++ b/reference/src/Skyscraper/KATs.hs @@ -4,26 +4,30 @@ -- -- +{-# LANGUAGE TypeApplications, DataKinds #-} module Skyscraper.KATs where -------------------------------------------------------------------------------- +import Data.Proxy + import BN254 +import Skyscraper.FieldExt import Skyscraper.Permutation -------------------------------------------------------------------------------- testSkyscraperKATs :: [Bool] testSkyscraperKATs = [ok1,ok2,ok3] where - ok1 = (perm1 input1 == output1) - ok2 = (perm2 input2 == output2) - ok3 = (perm3 input3 == output3) + ok1 = (permute (Proxy @Sky2) input1 == output1) + ok2 = (permute (Proxy @Sky4) input2 == output2) + ok3 = (permute (Proxy @Sky6) input3 == output3) -------------------------------------------------------------------------------- -- Note: the official implementation stores the field extension vector in the "wrong" order -mkExt2 (b,a) = MkExt2 a b -mkExt3 (c,b,a) = MkExt3 a b c +revMkExt2 (b,a) = MkExt2 a b +revMkExt3 (c,b,a) = MkExt3 a b c input1 :: (F,F) input1 = @@ -39,26 +43,26 @@ output1 = input2 :: (F2,F2) input2 = - ( mkExt2 (0x0004d2 , 0x00029a) - , mkExt2 (0x000162e, 0x000309) + ( revMkExt2 (0x0004d2 , 0x00029a) + , revMkExt2 (0x000162e, 0x000309) ) output2 :: (F2,F2) output2 = - ( mkExt2 ( 0x2456fa7300e7899364d1b7b933ee989a11606ad64b3166bb6e2822d46b2979db, 0x151084b3967b629e9103c5b85cf76bf47557d71a492e9575eb0dc2ac7bac0af3 ) - , mkExt2 ( 0x1f7964c2b3b354824906659089ac272aaa84cb50214c200bcf674d677cc83ee2, 0x29c9023fff7db7812cde8269233632415bb56ea4c8644c1edcfdbde68cc29dfe ) + ( revMkExt2 ( 0x2456fa7300e7899364d1b7b933ee989a11606ad64b3166bb6e2822d46b2979db, 0x151084b3967b629e9103c5b85cf76bf47557d71a492e9575eb0dc2ac7bac0af3 ) + , revMkExt2 ( 0x1f7964c2b3b354824906659089ac272aaa84cb50214c200bcf674d677cc83ee2, 0x29c9023fff7db7812cde8269233632415bb56ea4c8644c1edcfdbde68cc29dfe ) ) input3 :: (F3,F3) input3 = - ( mkExt3 (0x0004d2 , 0x00029a, 0x0003e9 ) - , mkExt3 (0x000162e, 0x000309, 0x0007d2 ) + ( revMkExt3 (0x0004d2 , 0x00029a, 0x0003e9 ) + , revMkExt3 (0x000162e, 0x000309, 0x0007d2 ) ) output3 :: (F3,F3) output3 = - ( mkExt3 ( 0x2ca9324e4d13668786f9f2dadf2ac3baf75f4bd57e14150c3421061d377edb6f, 0x2b87210ee7202515405c813b366da0b944c393e13332fd746ac19629c3b86486, 0x1bf5b2a7bed61ddd44f1d5a01492f203bdfd4973d68f3d91dddfdb8bc5b2db70 ) - , mkExt3 ( 0x9896a6aa7e71659af7c11a42e6c95a361225befaf1613e7253c5224165ecaf5 , 0x1f40923d58dbe5ee7b9bd58ff493cd5141b0fd75da57434c41aeb3ddfcbe3c37, 0x26afdb536cae8d809d0609a51c463d80198d2530e32053c5e888301953dfe670 ) + ( revMkExt3 ( 0x2ca9324e4d13668786f9f2dadf2ac3baf75f4bd57e14150c3421061d377edb6f, 0x2b87210ee7202515405c813b366da0b944c393e13332fd746ac19629c3b86486, 0x1bf5b2a7bed61ddd44f1d5a01492f203bdfd4973d68f3d91dddfdb8bc5b2db70 ) + , revMkExt3 ( 0x9896a6aa7e71659af7c11a42e6c95a361225befaf1613e7253c5224165ecaf5 , 0x1f40923d58dbe5ee7b9bd58ff493cd5141b0fd75da57434c41aeb3ddfcbe3c37, 0x26afdb536cae8d809d0609a51c463d80198d2530e32053c5e888301953dfe670 ) ) -------------------------------------------------------------------------------- diff --git a/reference/src/Skyscraper/Permutation.hs b/reference/src/Skyscraper/Permutation.hs index 5a95a97..a4c3649 100644 --- a/reference/src/Skyscraper/Permutation.hs +++ b/reference/src/Skyscraper/Permutation.hs @@ -1,5 +1,5 @@ -{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE BangPatterns, DataKinds, TypeFamilies, ScopedTypeVariables, TypeApplications #-} module Skyscraper.Permutation where -------------------------------------------------------------------------------- @@ -7,72 +7,135 @@ module Skyscraper.Permutation where import Data.Bits import Data.Word +import Data.Kind +import Data.Proxy + import BN254 import Skyscraper.RoundConst +import Skyscraper.FieldExt -------------------------------------------------------------------------------- -type F2 = Ext2 F -type F3 = Ext3 F +data SkyWidth + = Sky2 -- ^ 2 field elements + | Sky4 -- ^ degree 2 field extension (= 4 field elements) + | Sky6 -- ^ degree 3 field extension (= 6 field elements) + deriving (Eq,Show) -------------------------------------------------------------------------------- --- | @F[X] / (X^2 + 5)@ -data Ext2 a - = MkExt2 !a !a - deriving (Eq,Show) +type Digest = F -negF2 :: F2 -> F2 -negF2 (MkExt2 a0 a1) = MkExt2 (negate a0) (negate a1) +type State sky = (FieldExt sky, FieldExt sky) -addF2 :: F2 -> F2 -> F2 -addF2 (MkExt2 a0 a1) (MkExt2 b0 b1) = MkExt2 (a0+b0) (a1+b1) +stateFromList :: forall (sky :: SkyWidth). SkyscraperImpl sky => Proxy sky -> [F] -> State sky +stateFromList pxy xs = case splitAt (dimension $ Proxy @(FieldExt sky)) xs of + (us,vs) -> (extFromList us, extFromList vs) -mulF2 :: F2 -> F2 -> F2 -mulF2 (MkExt2 a0 a1) (MkExt2 b0 b1) = MkExt2 c0 c1 where - c0 = a0*b0 - 5*a1*b1 - c1 = a0*b1 + a1*b0 +stateToList :: SkyscraperImpl sky => Proxy sky -> State sky -> [F] +stateToList pxy (x,y) = extToList x ++ extToList y -sclF2 :: F -> F2 -> F2 -sclF2 s (MkExt2 a0 a1) = MkExt2 (s*a0) (s*a1) +extractDigest :: SkyscraperImpl sky => Proxy sky -> State sky -> F +extractDigest pxy = head . stateToList pxy -instance Num F2 where - fromInteger x = MkExt2 (fromInteger x) 0 - negate = negF2 - (+) = addF2 - (*) = mulF2 - signum = error "Num/F2/signum" - abs = error "Num/F2/abs" +initialState :: SkyscraperImpl sky => Proxy sky -> State sky +initialState pxy = stateFromList pxy $ replicate (rate pxy) 0 ++ [capacityIV pxy] ----------------------------------------- +-------------------------------------------------------------------------------- --- | @F[X] / (X^3 + 3)@ -data Ext3 a - = MkExt3 !a !a !a - deriving (Eq,Show) +class FieldExtension (FieldExt sky) => SkyscraperImpl (sky :: SkyWidth) where -negF3 :: F3 -> F3 -negF3 (MkExt3 a0 a1 a2) = MkExt3 (negate a0) (negate a1) (negate a2) + type FieldExt sky :: Type -addF3 :: F3 -> F3 -> F3 -addF3 (MkExt3 a0 a1 a2) (MkExt3 b0 b1 b2) = MkExt3 (a0+b0) (a1+b1) (a2+b2) + stateWidth :: Proxy sky -> Int + rate :: Proxy sky -> Int + capacity :: Proxy sky -> Int + capacityIV :: Proxy sky -> F -mulF3 :: F3 -> F3 -> F3 -mulF3 (MkExt3 a0 a1 a2) (MkExt3 b0 b1 b2) = MkExt3 c0 c1 c2 where - c0 = a0*b0 - 3*a2*b1 - 3*a1*b2 - c1 = a1*b0 + a0*b1 - 3*a2*b2 - c2 = a2*b0 + a1*b1 + a0*b2 + barsOnly :: Proxy sky -> FieldExt sky -> FieldExt sky + roundConst :: Proxy sky -> Int -> FieldExt sky -sclF3 :: F -> F3 -> F3 -sclF3 s (MkExt3 a0 a1 a2) = MkExt3 (s*a0) (s*a1) (s*a2) + capacity pxy = 1 + rate pxy = stateWidth pxy - capacity pxy + stateWidth pxy = capacity pxy + rate pxy -instance Num F3 where - fromInteger x = MkExt3 (fromInteger x) 0 0 - negate = negF3 - (+) = addF3 - (*) = mulF3 - signum = error "Num/F3/signum" - abs = error "Num/F3/abs" +-------------------------------------------------------------------------------- + +bars :: SkyscraperImpl sky => Proxy sky -> FieldExt sky -> State sky -> State sky +bars pxy rc (l,r) = (l',r') where + l' = r + barsOnly pxy l + rc + r' = l + +sq :: SkyscraperImpl sky => Proxy sky -> FieldExt sky -> State sky -> State sky +sq pxy rc (l,r) = (l',r') where + l' = r + scale invMontMultiplier (l*l) + rc + r' = l + +permute :: SkyscraperImpl sky => Proxy sky -> State sky -> State sky +permute pxy + = sq pxy (roundConst pxy 9) + . sq pxy (roundConst pxy 8) + . bars pxy (roundConst pxy 7) + . bars pxy (roundConst pxy 6) + . sq pxy (roundConst pxy 5) + . sq pxy (roundConst pxy 4) + . bars pxy (roundConst pxy 3) + . bars pxy (roundConst pxy 2) + . sq pxy (roundConst pxy 1) + . sq pxy (roundConst pxy 0) + +-------------------------------------------------------------------------------- + +compress :: SkyscraperImpl sky => Proxy sky -> Digest -> Digest -> Digest +compress pxy x y = extractDigest pxy $ permute pxy $ input where + input = stateFromList pxy $ [x,y] ++ replicate (stateWidth pxy - 2) 0 + +spongeWithPad :: SkyscraperImpl sky => Proxy sky -> [F] -> Digest +spongeWithPad pxy xs = spongeNoPad pxy (xs ++ [1]) + +spongeNoPad :: forall (sky :: SkyWidth). SkyscraperImpl sky => Proxy sky -> [F] -> Digest +spongeNoPad _ [] = error "spongeNoPad: empty input" +spongeNoPad pxy list = extractDigest pxy $ go (initialState pxy) list where + go :: State sky -> [F] -> State sky + go state [] = state + go state what = case splitAt (rate pxy) what of + (this,rest) -> go (permute pxy (mixDataWithState pxy this state)) rest + +mixDataWithState :: SkyscraperImpl sky => Proxy sky -> [F] -> State sky -> State sky +mixDataWithState pxy input state + | m == 0 = error "mixDataWithState: input is empty" + | m > r = error "mixDataWithState: input longer than the rate" + | otherwise = stateFromList pxy $ mix $ stateToList pxy state + where + r = rate pxy + m = length input + mix old = zipWith (+) old (input ++ repeat 0) + +-------------------------------------------------------------------------------- + +instance SkyscraperImpl Sky2 where + type FieldExt Sky2 = F + stateWidth _ = 2 + capacity _ = 1 + capacityIV _ = 1 + barsOnly _ = barsOnly1 + roundConst _ = rc1 + +instance SkyscraperImpl Sky4 where + type FieldExt Sky4 = F2 + stateWidth _ = 4 + capacity _ = 1 + capacityIV _ = 2 + barsOnly _ = barsOnly2 + roundConst _ = rc2 + +instance SkyscraperImpl Sky6 where + type FieldExt Sky6 = F3 + stateWidth _ = 6 + capacity _ = 1 + capacityIV _ = 3 + barsOnly _ = barsOnly3 + roundConst _ = rc3 -------------------------------------------------------------------------------- @@ -149,40 +212,6 @@ barsOnly3 (MkExt3 inp0 inp1 inp2) = MkExt3 out0 out1 out2 where -------------------------------------------------------------------------------- -bars1 :: F -> (F,F) -> (F,F) -bars1 rc (l,r) = (l',r') where - l' = r + barsOnly1 l + rc - r' = l - -bars2 :: F2 -> (F2,F2) -> (F2,F2) -bars2 rc (l,r) = (l',r') where - l' = r + barsOnly2 l + rc - r' = l - -bars3 :: F3 -> (F3,F3) -> (F3,F3) -bars3 rc (l,r) = (l',r') where - l' = r + barsOnly3 l + rc - r' = l - --------------------------------------------------------------------------------- - -sq1 :: F -> (F,F) -> (F,F) -sq1 rc (l,r) = (l',r') where - l' = r + invMontMultiplier*l*l + rc - r' = l - -sq2 :: F2 -> (F2,F2) -> (F2,F2) -sq2 rc (l,r) = (l',r') where - l' = r + sclF2 invMontMultiplier (l*l) + rc - r' = l - -sq3 :: F3 -> (F3,F3) -> (F3,F3) -sq3 rc (l,r) = (l',r') where - l' = r + sclF3 invMontMultiplier (l*l) + rc - r' = l - --------------------------------------------------------------------------------- - lkpKst :: Int -> F lkpKst k = roundConstF !! k @@ -211,43 +240,4 @@ rc3 k -------------------------------------------------------------------------------- -perm1 :: (F,F) -> (F,F) -perm1 - = sq1 (rc1 9) - . sq1 (rc1 8) - . bars1 (rc1 7) - . bars1 (rc1 6) - . sq1 (rc1 5) - . sq1 (rc1 4) - . bars1 (rc1 3) - . bars1 (rc1 2) - . sq1 (rc1 1) - . sq1 (rc1 0) -perm2 :: (F2,F2) -> (F2,F2) -perm2 - = sq2 (rc2 9) - . sq2 (rc2 8) - . bars2 (rc2 7) - . bars2 (rc2 6) - . sq2 (rc2 5) - . sq2 (rc2 4) - . bars2 (rc2 3) - . bars2 (rc2 2) - . sq2 (rc2 1) - . sq2 (rc2 0) - -perm3 :: (F3,F3) -> (F3,F3) -perm3 - = sq3 (rc3 9) - . sq3 (rc3 8) - . bars3 (rc3 7) - . bars3 (rc3 6) - . sq3 (rc3 5) - . sq3 (rc3 4) - . bars3 (rc3 3) - . bars3 (rc3 2) - . sq3 (rc3 1) - . sq3 (rc3 0) - ---------------------------------------------------------------------------------