From 055942a482b0c87fbe8b1c429667e5ad710f03fe Mon Sep 17 00:00:00 2001 From: rymnc <43716372+rymnc@users.noreply.github.com> Date: Sat, 13 Apr 2024 01:45:40 +0300 Subject: [PATCH] fix: subtree strategy --- src/BinaryIMTMemory.sol | 18 +++++------------- src/RlnBase.sol | 12 ++++++++---- test/Rln.t.sol | 20 ++++++++++++++++---- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/src/BinaryIMTMemory.sol b/src/BinaryIMTMemory.sol index 6fae3cf..1c9c2cc 100644 --- a/src/BinaryIMTMemory.sol +++ b/src/BinaryIMTMemory.sol @@ -89,25 +89,17 @@ library BinaryIMTMemory { revert("IncrementalBinaryTree: defaultZero bad index"); } - /// @dev Computes the root of the tree given the leaves. - /// @param self: Tree data. - /// @param leaves: Leaves in the tree - function calcRoot( - BinaryIMTMemoryData memory self, - uint256 depth, - uint256[] memory leaves - ) - public - pure - returns (uint256) - { + function calcSubtreeRoot(BinaryIMTMemoryData memory self, uint256[] memory leaves) public pure returns (uint256) { + // instead of the depth being 0..19, we go from 10..19. + // this way, we preserve the subtree root + uint8 depth = 10; uint256[2][] memory lastSubtrees = new uint256[2][](depth); for (uint256 j = 0; j < leaves.length;) { uint256 index = self.numberOfLeaves; uint256 hash = leaves[j]; for (uint8 i = 0; i < depth;) { if (index & 1 == 0) { - lastSubtrees[i] = [hash, defaultZero(i)]; + lastSubtrees[i] = [hash, defaultZero(i + depth)]; } else { lastSubtrees[i][1] = hash; } diff --git a/src/RlnBase.sol b/src/RlnBase.sol index 6b1e38c..4e74970 100644 --- a/src/RlnBase.sol +++ b/src/RlnBase.sol @@ -308,18 +308,22 @@ abstract contract RlnBase { function root(uint256 shardIndex) public view returns (uint256) { BinaryIMTMemoryData memory imtData; uint256 idCommitmentIndex = leavesSet[shardIndex]; - uint256[] memory calcLeaves = new uint256[](idCommitmentIndex); + uint256[] memory calcLeaves = new uint256[](1024); + + if (idCommitmentIndex == 0) { + return BinaryIMTMemory.calcSubtreeRoot(imtData, calcLeaves); + } for (uint256 i = 0; i < idCommitmentIndex; i++) { uint256 idCommitment = leavesIndex[shardIndex][i]; uint256 userMessageLimit = userMessageLimits[idCommitment]; calcLeaves[i] = PoseidonT3.hash([idCommitment, userMessageLimit]); } - return BinaryIMTMemory.calcRoot(imtData, 10, calcLeaves); + return BinaryIMTMemory.calcSubtreeRoot(imtData, calcLeaves); } - function full_root(uint256[] calldata roots) public pure returns (uint256) { + function fullRoot(uint256[] calldata roots) public pure returns (uint256) { BinaryIMTMemoryData memory imtData; - return BinaryIMTMemory.calcRoot(imtData, 10, roots); + return BinaryIMTMemory.calcSubtreeRoot(imtData, roots); } function getLeafAtShard(uint256 shardIndex, uint256 index) public view returns (uint256) { diff --git a/test/Rln.t.sol b/test/Rln.t.sol index e3c2c48..ad64a53 100644 --- a/test/Rln.t.sol +++ b/test/Rln.t.sol @@ -263,6 +263,11 @@ contract RlnTest is Test { 18_217_334_211_520_937_958_971_536_517_166_530_749_184_547_628_672_204_353_760_850_739_130_586_503_124; vm.pauseGasMetering(); + // before we insert into the tree, the kats must pass + assertEq( + rln.root(0), + 12_413_880_268_183_407_374_852_357_075_976_609_371_175_688_755_676_981_206_018_884_971_008_854_919_922 + ); for (uint256 i = 0; i < idCommitments.length; i++) { // default 1 message limit rln.register{ value: MEMBERSHIP_DEPOSIT }(idCommitments[i], 1); @@ -271,7 +276,14 @@ contract RlnTest is Test { assertEq( rln.root(0), - 1_009_702_141_963_982_084_971_194_921_450_594_432_549_814_353_062_162_436_132_398_913_014_207_593_991 + 2_435_591_050_584_562_785_273_516_085_683_091_349_247_653_357_522_297_485_884_387_709_399_190_797_862 + ); + + uint256[] memory roots = new uint256[](1024); + roots[0] = 2_435_591_050_584_562_785_273_516_085_683_091_349_247_653_357_522_297_485_884_387_709_399_190_797_862; + assertEq( + rln.fullRoot(roots), + 3_881_711_806_223_377_894_370_509_479_419_308_415_025_371_832_767_296_831_624_691_168_726_496_442_721 ); } @@ -291,15 +303,15 @@ contract RlnTest is Test { vm.resumeGasMetering(); assertEq( rln.root(0), - 15_191_547_950_571_875_204_938_158_154_881_229_414_576_184_216_597_114_679_496_514_970_950_611_126_017 + 10_424_638_537_885_815_866_594_133_227_550_095_274_095_476_771_249_698_162_779_016_221_356_823_894_456 ); } - function test__full_root() public { + function test__fullRoot() public { uint256[] memory roots = new uint256[](1024); assertEq( // repeat 0 1024 times - rln.full_root(roots), + rln.fullRoot(roots), 12_413_880_268_183_407_374_852_357_075_976_609_371_175_688_755_676_981_206_018_884_971_008_854_919_922 ); }