From 55015008e791a25f8d6153f884b846dd9ad7a3bd Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Tue, 28 Nov 2023 12:32:36 +0100 Subject: [PATCH] the circuit seems to work --- circuit/extract_bits.circom | 2 + circuit/log2.circom | 9 +- circuit/merkle.circom | 87 +++++++++++++++++++ circuit/misc.circom | 2 +- circuit/poseidon2_merkle.circom | 46 ---------- circuit/poseidon2_perm.circom | 15 ++++ circuit/poseidon2_sponge.circom | 9 +- circuit/sample_cells.circom | 129 ++++++++++++++++------------ circuit/single_cell.circom | 136 ++++++++++-------------------- circuit/slot_main.circom | 3 +- circuit/test_prove.sh | 35 ++++++++ circuit/test_slotmain.sh | 14 +-- reference/haskell/cli/testMain.hs | 14 +-- reference/haskell/src/DataSet.hs | 27 ++++-- reference/haskell/src/Sampling.hs | 13 ++- test/Circuit/CeilingLog2.hs | 8 +- test/Main.hs | 2 +- 17 files changed, 327 insertions(+), 224 deletions(-) create mode 100644 circuit/merkle.circom delete mode 100644 circuit/poseidon2_merkle.circom create mode 100755 circuit/test_prove.sh diff --git a/circuit/extract_bits.circom b/circuit/extract_bits.circom index 27f4b0c..6b28e6e 100644 --- a/circuit/extract_bits.circom +++ b/circuit/extract_bits.circom @@ -11,6 +11,8 @@ include "misc.circom"; // NOTE: this is rather nontrivial, as everything is computed modulo `r`, // so naive bit decomposition does not work (there are multiple solutions). // +// TODO: optimize this +// template ExtractLowerBits(n) { diff --git a/circuit/log2.circom b/circuit/log2.circom index 78b3f74..a0a6150 100644 --- a/circuit/log2.circom +++ b/circuit/log2.circom @@ -42,11 +42,13 @@ template Log2(n) { //------------------------------------------------------------------------------ // -// given an input `inp`, this template computes `k` such that 2^k <= inp < 2^{k+1} +// given an input `inp`, this template computes `out := k` such that 2^k <= inp < 2^{k+1} // it also returns the binary decomposition of `inp-1`, and the binary deocmpositiom // of the mask `(2^k-1)` // -// we also output a mask vector which is 1 for i=0..out-1, and 0 elsewhere +// we also output a mask vector which is 1 for i=0..k-1, and 0 elsewhere +// +// we require `k <= n`, otherwise this will fail. // template CeilingLog2(n) { @@ -54,7 +56,7 @@ template CeilingLog2(n) { signal input inp; signal output out; signal output bits[n]; - signal output mask[n]; + signal output mask[n+1]; component tb = ToBits(n); tb.inp <== inp - 1; @@ -68,6 +70,7 @@ template CeilingLog2(n) { mask[i] <== 1 - aux[i]; sum = sum + (aux[i+1] - aux[i]) * (i+1); } + mask[n] <== 0; out <== sum; } diff --git a/circuit/merkle.circom b/circuit/merkle.circom new file mode 100644 index 0000000..beab67b --- /dev/null +++ b/circuit/merkle.circom @@ -0,0 +1,87 @@ +pragma circom 2.0.0; + +include "poseidon2_perm.circom"; +include "poseidon2_hash.circom"; + +include "misc.circom"; + +//------------------------------------------------------------------------------ + +// +// reconstruct the Merkle root using a Merkle inclusion proof +// +// parameters: +// - depth: the depth of the Merkle tree = log2( numberOfLeaves ) +// +// inputs and outputs: +// - leaf: the leaf hash +// - pathBits: the linear index of the leaf, in binary decomposition (little-endian) +// - lastBits: the index of the last leaf (= nLeaves-1), in binary decomposition +// - maskBits: the bits of the the mask `2^ceilingLog2(size) - 1` +// - merklePath: the Merkle inclusion proof (required hashes, starting from the leaf and ending near the root) +// - recRoot: the reconstructod Merkle root +// +// NOTE: we don't check whether the bits are really bits, that's the +// responsability of the caller! +// + +template RootFromMerklePath( maxDepth ) { + + signal input leaf; + signal input pathBits[ maxDepth ]; // bits of the linear index + signal input lastBits[ maxDepth ]; // bits of the last linear index `= size-1` + signal input maskBits[ maxDepth+1 ]; // bit mask for `2^ceilingLog(size) - 1` + signal input merklePath[ maxDepth ]; + signal output recRoot; + + // the sequence of reconstructed hashes along the path + signal aux[ maxDepth+1 ]; + aux[0] <== leaf; + + // compute which prefixes (in big-endian) of the index is + // the same as the corresponding prefix of the last index + component eq[ maxDepth ]; + signal isLast[ maxDepth+1 ]; + isLast[ maxDepth ] <== 1; + for(var i=maxDepth-1; i>=0; i--) { + eq[i] = IsEqual(); + eq[i].A <== pathBits[i]; + eq[i].B <== lastBits[i]; + isLast[i] <== isLast[i+1] * eq[i].out; + } + + // compute the sequence of hashes + signal switch[ maxDepth ]; + component comp[ maxDepth ]; + for(var i=0; i aux[i+1]; + } + + // now we need to select the right layer from the sequence of hashes + var sum = 0; + signal prods[maxDepth]; + for(var i=0; i out; } diff --git a/circuit/poseidon2_merkle.circom b/circuit/poseidon2_merkle.circom deleted file mode 100644 index f0bf4c1..0000000 --- a/circuit/poseidon2_merkle.circom +++ /dev/null @@ -1,46 +0,0 @@ -pragma circom 2.0.0; - -include "poseidon2_perm.circom"; - -//------------------------------------------------------------------------------ -// Merkle tree built using the Poseidon2 permutation -// -// The number of leaves is `2**nlevels` -// - -template PoseidonMerkle(nlevels) { - - var nleaves = 2**nlevels; - - signal input inp[nleaves]; - signal output out_root; - - component hsh[ nleaves-1]; - signal aux[2*nleaves-1]; - - for(var k=0; k aux[b+ k ]; - } - - a = b; - u = v; - } - - aux[2*nleaves-2] ==> out_root; -} - -//------------------------------------------------------------------------------ diff --git a/circuit/poseidon2_perm.circom b/circuit/poseidon2_perm.circom index 33cdb53..c518378 100644 --- a/circuit/poseidon2_perm.circom +++ b/circuit/poseidon2_perm.circom @@ -213,5 +213,20 @@ template Compression() { perm.out[0] ==> out; } +//-------------------------------------- + +template KeyedCompression() { + signal input key; + signal input inp[2]; + signal output out; + + component perm = Permutation(); + perm.inp[0] <== inp[0]; + perm.inp[1] <== inp[1]; + perm.inp[2] <== key; + + perm.out[0] ==> out; +} + //------------------------------------------------------------------------------ diff --git a/circuit/poseidon2_sponge.circom b/circuit/poseidon2_sponge.circom index 70bb526..d7698bb 100644 --- a/circuit/poseidon2_sponge.circom +++ b/circuit/poseidon2_sponge.circom @@ -51,9 +51,14 @@ template PoseidonSponge(t, capacity, input_len, output_len) { signal state [nblocks+nout][t ]; signal sorbed[nblocks ][rate]; - + + // domain separation, capacity IV := 2^64 + 256*t + rate + var civ = 2**64 + 256*t + rate; + // log("capacity IV = ",civ); + // initialize state - for(var i=0; i hash; - // extract the lowest `log2(nCells)` bits - component md = ExtractLowerBits(log2N); + // extract the lowest `maxLog2N = 32` bits + component md = ExtractLowerBits(maxLog2N); md.inp <== hash; - md.out ==> indexBits; -} - -//------------------------------------------------------------------------------ - -// -// same as above, but returns an integer index instead of its binary decomposition. -// - -template CalculateCellIndexInteger( nCells ) { - - var log2N = CeilLog2(nCells); - assert( nCells == (1< aux[i+1]; - } - - aux[depth] ==> recRoot; -} - -//-------------------------------------- - -// -// a version of the above where the leaf index is given as an integer -// instead of a sequence of bits -// - -template MerklePathIndex( depth ) { - - signal input leaf; - signal input linearIndex; - signal input merklePath[ depth ]; - signal output recRoot; - - // decompose the linear cell index into bits (0 = left, 1 = right) - component tb = ToBits( depth ); - component path = MerklePathBits( depth ); - - tb.inp <== linearIndex; - tb.out ==> path.pathBits; - - path.leaf <== leaf; - path.merklePath <== merklePath; - path.recRoot ==> recRoot; -} - -//------------------------------------------------------------------------------ - // // calculates a single cell's hash and reconstructs the Merkle root, // checking whether it matches the given slot root // // parameters: -// - nFieldElemsPerCell: how many field elements a cell consists of -// - merkleDepth: the depth of slot subtree = log2(nCellsPerSlot) +// - nFieldElemsPerCell: how many field elements a cell consists of (= 2048/31 = 67) +// - botDepth: the depth of the per-block minitree (= 5) +// - maxDepth: the maximum depth of slot subtree (= 32) // // inputs and outputs: -// - indexBits: the linear index of the cell, within the slot subtree, in binary -// - data: the cell data (already encoded as field elements) -// - merklePath: the Merkle inclusion proof -// - slotRoot: the expected slot root +// - indexBits: the linear index of the cell, within the slot subtree, in binary +// - lastBits: the index of the last cell (size - 1), in binary (required for odd-even node key) +// - maskBits: the binary mask of the size rounded up to a power of two +// - data: the cell data (already encoded as field elements) +// - merklePath: the Merkle inclusion proof +// - slotRoot: the expected slot root // // NOTE: we don't check whether the bits are really bits, that's the -// responsability of the caller! +// responsability of the caller! // -template ProveSingleCell( nFieldElemsPerCell, merkleDepth ) { +template ProveSingleCell( nFieldElemsPerCell, botDepth, maxDepth ) { signal input slotRoot; signal input data[ nFieldElemsPerCell ]; + signal input lastBits [ maxDepth ]; + signal input indexBits [ maxDepth ]; + signal input maskBits [ maxDepth + 1 ]; + signal input merklePath[ maxDepth ]; - signal input indexBits [ merkleDepth ]; - signal input merklePath[ merkleDepth ]; + // these will reconstruct the Merkle path up to the slot root + // in two steps: first the block-level ("bottom"), then the slot-level ("middle") + component pbot = RootFromMerklePath( botDepth ); + component pmid = RootFromMerklePath( maxDepth - botDepth ); - // this will reconstruct the Merkle path up to the slot root - component path = MerklePathBits( merkleDepth ); - path.pathBits <== indexBits; - path.merklePath <== merklePath; + for(var i=0; i path.leaf; + hash.inp <== data; + hash.out ==> pbot.leaf; + pmid.leaf <== pbot.recRoot; + +log("middle bottom root check = ", pmid.recRoot == slotRoot); // check if the reconstructed root matches the actual slot root - path.recRoot === slotRoot; + pmid.recRoot === slotRoot; } diff --git a/circuit/slot_main.circom b/circuit/slot_main.circom index a2357bb..072ae14 100644 --- a/circuit/slot_main.circom +++ b/circuit/slot_main.circom @@ -1,3 +1,4 @@ pragma circom 2.0.0; include "sample_cells.circom"; -component main {public [entropy,slotRoot]} = SampleAndProveV1(1024, 5, 10); +// SampleAndProven( maxDepth, maxLog2NSlots, blockTreeDepth, nFieldElemsPerCell, nSamples ) +component main {public [entropy,dataSetRoot,slotIndex]} = SampleAndProve(16, 5, 3, 5, 10); diff --git a/circuit/test_prove.sh b/circuit/test_prove.sh new file mode 100755 index 0000000..fc371e6 --- /dev/null +++ b/circuit/test_prove.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +ORIG=`pwd` + +PTAU_DIR="/Users/bkomuves/zk/ptau/" +PTAU_FILE="${PTAU_DIR}/powersOfTau28_hez_final_20.ptau" + +NAME="slot_main" + +# --- setup --- + +cd $ORIG/build +echo "circuit-specific ceremony..." +snarkjs groth16 setup ${NAME}.r1cs ${PTAU_FILE} ${NAME}_0000.zkey +echo "some_entropy_xxx" | snarkjs zkey contribute ${NAME}_0000.zkey ${NAME}_0001.zkey --name="1st Contributor Name" +rm ${NAME}_0000.zkey +mv ${NAME}_0001.zkey ${NAME}.zkey +snarkjs zkey export verificationkey ${NAME}.zkey ${NAME}_verification_key.json + +# --- prove --- + +echo "" +echo "trying to prove... (with snarkjs)" + +cd $ORIG/build +time snarkjs groth16 prove ${NAME}.zkey ${NAME}.wtns ${NAME}_proof.json ${NAME}_public.json + +# --- verify --- + +echo "" +echo "verifyng proof..." +snarkjs groth16 verify ${NAME}_verification_key.json ${NAME}_public.json ${NAME}_proof.json + +cd $ORIG +cd $ORIG diff --git a/circuit/test_slotmain.sh b/circuit/test_slotmain.sh index f1bea07..473bb5f 100755 --- a/circuit/test_slotmain.sh +++ b/circuit/test_slotmain.sh @@ -1,21 +1,25 @@ #!/bin/bash ORIG=`pwd` +mkdir -p build cd ../reference/haskell -runghc testMain.hs || { echo "ghc failed"; exit 101; } +mkdir -p json +cabal v1-run cli/testMain.hs || { echo "ghc failed"; cd $ORIG; exit 101; } -mv input_example.json ${ORIG}/build/ -mv slot_main.circom ${ORIG} +mv json/input_example.json ${ORIG}/build/ +mv json/slot_main.circom ${ORIG} cd ${ORIG}/build NAME="slot_main" -circom ../${NAME}.circom --r1cs --wasm || { echo "circom failed"; exit 102; } +circom ../${NAME}.circom --r1cs --wasm || { echo "circom failed"; cd $ORIG; exit 102; } echo "generating witness... (WASM)" cd ${NAME}_js -node generate_witness.js ${NAME}.wasm ../input_example.json ../${NAME}_witness.wtns || { echo "witness gen failed"; exit 101; } +node generate_witness.js ${NAME}.wasm ../input_example.json ../${NAME}.wtns || { echo "witness gen failed"; cd $ORIG; exit 101; } cd .. +echo "witness generation succeeded" + cd $ORIG diff --git a/reference/haskell/cli/testMain.hs b/reference/haskell/cli/testMain.hs index d6bb504..86ee6e6 100644 --- a/reference/haskell/cli/testMain.hs +++ b/reference/haskell/cli/testMain.hs @@ -11,17 +11,21 @@ import Sampling smallDataSetCfg :: DataSetCfg smallDataSetCfg = MkDataSetCfg - { _nSlots = 5 + { _maxDepth = 16 + , _maxLog2NSlots = 5 + , _nSlots = 5 , _cellSize = 128 - , _blockSize = 4096 - , _nCells = 256 - , _nSamples = 5 + , _blockSize = 1024 + , _nCells = 64 + , _nSamples = 10 , _dataSrc = FakeData (Seed 12345) } bigDataSetCfg :: DataSetCfg bigDataSetCfg = MkDataSetCfg - { _nSlots = 13 + { _maxDepth = 32 + , _maxLog2NSlots = 8 + , _nSlots = 13 , _cellSize = 2048 , _blockSize = 65536 , _nCells = 512 diff --git a/reference/haskell/src/DataSet.hs b/reference/haskell/src/DataSet.hs index cd5c97e..844cb53 100644 --- a/reference/haskell/src/DataSet.hs +++ b/reference/haskell/src/DataSet.hs @@ -4,19 +4,24 @@ module DataSet where -------------------------------------------------------------------------------- +import Data.List import System.FilePath import Slot hiding ( MkSlotCfg(..) ) import qualified Slot as Slot +import Misc + -------------------------------------------------------------------------------- data DataSetCfg = MkDataSetCfg - { _nSlots :: Int -- ^ number of slots per dataset - , _cellSize :: Int - , _blockSize :: Int - , _nCells :: Int - , _nSamples :: Int + { _maxDepth :: Int -- ^ @nCells@ must fit into this many bits + , _maxLog2NSlots :: Int -- ^ @nSlots@ must fit into this many bits + , _nSlots :: Int -- ^ number of slots per dataset + , _cellSize :: Int -- ^ cell size in bytes + , _blockSize :: Int -- ^ slot size in bytes + , _nCells :: Int -- ^ number of cells per slot + , _nSamples :: Int -- ^ number of cells we sample in a proof , _dataSrc :: DataSource } deriving Show @@ -57,12 +62,20 @@ loadDataSetBlock dsetCfg slotIdx@(SlotIdx idx) blockidx circomMainComponent :: DataSetCfg -> FilePath -> IO () circomMainComponent dsetCfg circomFile = do - let params = show (DataSet.fieldElemsPerCell dsetCfg) - ++ ", " ++ show (DataSet._nSamples dsetCfg) + let cellsPerBlock = (DataSet._blockSize dsetCfg) `div` (DataSet._cellSize dsetCfg) + let blockDepth = ceilingLog2 (fromIntegral cellsPerBlock) + let params = intercalate ", " $ map show + [ DataSet._maxDepth dsetCfg + , DataSet._maxLog2NSlots dsetCfg + , blockDepth + , DataSet.fieldElemsPerCell dsetCfg + , DataSet._nSamples dsetCfg + ] writeFile circomFile $ unlines [ "pragma circom 2.0.0;" , "include \"sample_cells.circom\";" + , "// SampleAndProven( maxDepth, maxLog2NSlots, blockTreeDepth, nFieldElemsPerCell, nSamples ) " , "component main {public [entropy,dataSetRoot,slotIndex]} = SampleAndProve(" ++ params ++ ");" ] diff --git a/reference/haskell/src/Sampling.hs b/reference/haskell/src/Sampling.hs index eb4504d..aa9349f 100644 --- a/reference/haskell/src/Sampling.hs +++ b/reference/haskell/src/Sampling.hs @@ -36,6 +36,15 @@ sampleCellIndex cfg entropy slotRoot counter = CellIdx (fromInteger idx) where -------------------------------------------------------------------------------- +padWithZeros :: Int -> [Fr] -> [Fr] +padWithZeros n xs + | m <= n = xs ++ replicate (n-m) Fr.zero + | otherwise = error "padWithZeros: input too long" + where + m = length xs + +-------------------------------------------------------------------------------- + data CircuitInput = MkInput { _entropy :: Entropy -- ^ public input , _dataSetRoot :: Hash -- ^ public input @@ -72,11 +81,11 @@ calculateCircuitInput dataSetCfg slotIdx@(SlotIdx sidx) entropy = do , _dataSetRoot = dsetRoot , _slotIndex = sidx , _slotRoot = ourSlotRoot - , _slotProof = extractMerkleProof_ dsetTree sidx + , _slotProof = padWithZeros (_maxLog2NSlots dataSetCfg) $ extractMerkleProof_ dsetTree sidx , _slotsPerDSet = nslots , _cellsPerSlot = Slot._nCells ourSlotCfg , _cellData = cellData - , _merklePaths = merklePaths + , _merklePaths = map (padWithZeros (_maxDepth dataSetCfg)) merklePaths } -- | Export the inputs of the storage proof circuits in JSON format, diff --git a/test/Circuit/CeilingLog2.hs b/test/Circuit/CeilingLog2.hs index f2b6041..fe6463a 100644 --- a/test/Circuit/CeilingLog2.hs +++ b/test/Circuit/CeilingLog2.hs @@ -30,12 +30,12 @@ type TestCase = Integer type Output = (Int,[Bool],[Bool]) semantics :: GP -> TestCase -> Expected Output -semantics n a - | a >0 && k >= 0 && k < n = Expecting (k,bits,mask) - | otherwise = ShouldFail +semantics n a + | a > 0 && k >= 0 && k <= n = Expecting (k,bits,mask) + | otherwise = ShouldFail where k = ceilingLog2 a - mask = [ i < k | i<-[0..n-1] ] + mask = [ i < k | i<-[0..n] ] bits = [ testBit (a-1) i | i<-[0..n-1] ] -- | Smallest integer @k@ such that @2^k@ is larger or equal to @n@ diff --git a/test/Main.hs b/test/Main.hs index 8804e01..e6e4918 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -40,5 +40,5 @@ testSimple' verbosity = do -------------------------------------------------------------------------------- main = do - testSimple' Silent -- Verbose -- Silent + testSimple' Info --Silent -- Verbose -- Silent