test: subtree root calculation strategy

This commit is contained in:
rymnc 2024-04-04 14:57:58 +03:00
parent 05d1a0a19f
commit 6239ad30b6
No known key found for this signature in database
GPG Key ID: AAA088D5C68ECD34
4 changed files with 152 additions and 83 deletions

File diff suppressed because one or more lines are too long

View File

@ -2,6 +2,7 @@
pragma solidity ^0.8.19;
import { PoseidonT3 } from "poseidon-solidity/PoseidonT3.sol";
import "forge-std/console2.sol";
// stripped down version of
// solhint-disable-next-line max-line-length
@ -101,8 +102,7 @@ library BinaryIMTMemory {
returns (uint256)
{
uint256[2][] memory lastSubtrees = new uint256[2][](depth);
for (uint8 j = 0; j < leaves.length; j++) {
for (uint256 j = 0; j < leaves.length;) {
uint256 index = self.numberOfLeaves;
uint256 hash = leaves[j];
for (uint8 i = 0; i < depth;) {
@ -119,6 +119,9 @@ library BinaryIMTMemory {
}
self.root = hash;
self.numberOfLeaves += 1;
unchecked {
++j;
}
}
return self.root;
}

View File

@ -53,6 +53,19 @@ abstract contract RlnBase {
uint256 public constant Q =
21_888_242_871_839_275_222_246_405_745_257_275_088_548_364_400_416_034_343_698_204_186_575_808_495_617;
/// To ensure that the roots are accesible at each level, we shard the tree into subtrees
/// The leaves of the subtrees are stored in the leaves array.
/// @notice The max number of leaves in each subtree
uint256 public immutable SUBTREE_SIZE = 1024;
/// @notice the number of subtrees
uint256 public immutable SUBTREE_0_LENGTH = 1024;
mapping(uint256 => mapping(uint256 => uint256)) public leaves;
mapping(uint256 => mapping(uint256 => uint256)) public leavesIndex;
mapping(uint256 => bool) public memberExists;
mapping(uint256 => uint256) public leavesSet;
/// @notice The max message limit per epoch
uint256 public immutable MAX_MESSAGE_LIMIT;
@ -65,27 +78,16 @@ abstract contract RlnBase {
/// @notice The size of the merkle tree, i.e 2^depth
uint256 public immutable SET_SIZE;
/// @notice The index of the next member to be registered
uint256 public idCommitmentIndex = 0;
uint256 public currentShardIndex = 0;
/// @notice The amount of eth staked by each member
/// maps from idCommitment to the amount staked
mapping(uint256 => uint256) public stakedAmounts;
/// @notice The membership status of each member
/// maps from idCommitment to their index in the set
mapping(uint256 => uint256) public members;
/// @notice the user message limit of each member
/// maps from idCommitment to their user message limit
mapping(uint256 => uint256) public userMessageLimits;
/// @notice the index to commitment mapping
mapping(uint256 => uint256) public indexToCommitment;
/// @notice The membership status of each member
mapping(uint256 => bool) public memberExists;
/// @notice The balance of each user that can be withdrawn
mapping(address => uint256) public withdrawalBalance;
@ -99,12 +101,12 @@ abstract contract RlnBase {
/// @param idCommitment The idCommitment of the member
/// @param userMessageLimit the user message limit of the member
/// @param index The index of the member in the set
event MemberRegistered(uint256 idCommitment, uint256 userMessageLimit, uint256 index);
event MemberRegistered(uint256 shardIndex, uint256 idCommitment, uint256 userMessageLimit, uint256 index);
/// Emitted when a member is removed from the set
/// @param idCommitment The idCommitment of the member
/// @param index The index of the member in the set
event MemberWithdrawn(uint256 idCommitment, uint256 index);
event MemberWithdrawn(uint256 shardIndex, uint256 idCommitment, uint256 index);
modifier onlyValidIdCommitment(uint256 idCommitment) {
if (!isValidCommitment(idCommitment)) revert InvalidIdCommitment(idCommitment);
@ -160,16 +162,23 @@ abstract contract RlnBase {
/// @param stake The amount of eth staked by the member
function _register(uint256 idCommitment, uint256 userMessageLimit, uint256 stake) internal virtual {
if (memberExists[idCommitment]) revert DuplicateIdCommitment();
if (idCommitmentIndex >= SET_SIZE) revert FullTree();
if (currentShardIndex == SUBTREE_0_LENGTH - 1 && leavesSet[currentShardIndex] == SUBTREE_SIZE - 1) {
revert FullTree();
}
uint256 index = leavesSet[currentShardIndex];
members[idCommitment] = idCommitmentIndex;
indexToCommitment[idCommitmentIndex] = idCommitment;
if (index == SUBTREE_SIZE - 1) {
currentShardIndex += 1;
}
leaves[currentShardIndex][idCommitment] = index;
leavesIndex[currentShardIndex][index] = idCommitment;
memberExists[idCommitment] = true;
stakedAmounts[idCommitment] = stake;
userMessageLimits[idCommitment] = userMessageLimit;
emit MemberRegistered(idCommitment, userMessageLimit, idCommitmentIndex);
idCommitmentIndex += 1;
emit MemberRegistered(currentShardIndex, idCommitment, userMessageLimit, index);
leavesSet[currentShardIndex] += 1;
}
/// @dev Inheriting contracts MUST override this function
@ -178,6 +187,7 @@ abstract contract RlnBase {
/// @dev Allows a user to slash a member
/// @param idCommitment The idCommitment of the member
function slash(
uint256 shardIndex,
uint256 idCommitment,
address payable receiver,
uint256[8] calldata proof
@ -187,14 +197,22 @@ abstract contract RlnBase {
onlyValidIdCommitment(idCommitment)
{
_validateSlash(idCommitment, receiver, proof);
_slash(idCommitment, receiver, proof);
_slash(shardIndex, idCommitment, receiver, proof);
}
/// @dev Slashes a member by removing them from the set, and adding their
/// stake to the receiver's available withdrawal balance
/// @param idCommitment The idCommitment of the member
/// @param receiver The address to receive the funds
function _slash(uint256 idCommitment, address payable receiver, uint256[8] calldata proof) internal virtual {
function _slash(
uint256 shardIndex,
uint256 idCommitment,
address payable receiver,
uint256[8] calldata proof
)
internal
virtual
{
if (receiver == address(this) || receiver == address(0)) {
revert InvalidReceiverAddress(receiver);
}
@ -212,9 +230,9 @@ abstract contract RlnBase {
uint256 amountToTransfer = stakedAmounts[idCommitment];
// delete member
uint256 index = members[idCommitment];
members[idCommitment] = 0;
indexToCommitment[index] = 0;
uint256 index = leaves[shardIndex][idCommitment];
leaves[shardIndex][idCommitment] = 0;
leavesIndex[shardIndex][index] = 0;
memberExists[idCommitment] = false;
stakedAmounts[idCommitment] = 0;
userMessageLimits[idCommitment] = 0;
@ -222,7 +240,7 @@ abstract contract RlnBase {
// refund deposit
withdrawalBalance[receiver] += amountToTransfer;
emit MemberWithdrawn(idCommitment, index);
emit MemberWithdrawn(shardIndex, idCommitment, index);
}
function _validateSlash(
@ -271,25 +289,44 @@ abstract contract RlnBase {
);
}
function getCommitments(uint256 startIndex, uint256 endIndex) public view returns (uint256[] memory) {
if (startIndex >= endIndex) revert InvalidPaginationQuery(startIndex, endIndex);
if (endIndex > idCommitmentIndex) revert InvalidPaginationQuery(startIndex, endIndex);
function getCommitments(uint256 shardIndex) public view returns (uint256[] memory) {
if (shardIndex >= SUBTREE_0_LENGTH) {
revert InvalidPaginationQuery(shardIndex, SUBTREE_0_LENGTH);
}
uint256[] memory commitments = new uint256[](endIndex - startIndex);
for (uint256 i = startIndex; i < endIndex; i++) {
commitments[i - startIndex] = indexToCommitment[i];
uint256 endIndex = leavesSet[shardIndex];
if (endIndex == 0) {
return new uint256[](0);
}
uint256[] memory commitments = new uint256[](endIndex);
for (uint256 i = 0; i < endIndex; i++) {
commitments[i] = leavesIndex[shardIndex][i];
}
return commitments;
}
function root() public view returns (uint256) {
function root(uint256 shardIndex) public view returns (uint256) {
BinaryIMTMemoryData memory imtData;
uint256[] memory leaves = new uint256[](idCommitmentIndex);
uint256 idCommitmentIndex = leavesSet[shardIndex];
uint256[] memory calcLeaves = new uint256[](idCommitmentIndex);
for (uint256 i = 0; i < idCommitmentIndex; i++) {
uint256 idCommitment = indexToCommitment[i];
uint256 idCommitment = leavesIndex[shardIndex][i];
uint256 userMessageLimit = userMessageLimits[idCommitment];
leaves[i] = PoseidonT3.hash([idCommitment, userMessageLimit]);
calcLeaves[i] = PoseidonT3.hash([idCommitment, userMessageLimit]);
}
return BinaryIMTMemory.calcRoot(imtData, DEPTH, leaves);
return BinaryIMTMemory.calcRoot(imtData, 10, calcLeaves);
}
function full_root(uint256[] calldata roots) public pure returns (uint256) {
BinaryIMTMemoryData memory imtData;
return BinaryIMTMemory.calcRoot(imtData, 10, roots);
}
function getLeafAtShard(uint256 shardIndex, uint256 index) public view returns (uint256) {
return leavesIndex[shardIndex][index];
}
function getLeafIndex(uint256 shardIndex, uint256 idCommitment) public view returns (uint256) {
return leaves[shardIndex][idCommitment];
}
}

View File

@ -44,7 +44,7 @@ contract RlnTest is Test {
rln.register{ value: MEMBERSHIP_DEPOSIT }(idCommitment, 1);
assertEq(rln.stakedAmounts(idCommitment), MEMBERSHIP_DEPOSIT);
assertEq(rln.memberExists(idCommitment), true);
assertEq(rln.members(idCommitment), 0);
assertEq(rln.getLeafAtShard(0, idCommitment), 0);
assertEq(rln.userMessageLimits(idCommitment), 1);
}
@ -53,7 +53,7 @@ contract RlnTest is Test {
rln.register{ value: MEMBERSHIP_DEPOSIT }(idCommitment, 1);
assertEq(rln.stakedAmounts(idCommitment), MEMBERSHIP_DEPOSIT);
assertEq(rln.memberExists(idCommitment), true);
assertEq(rln.members(idCommitment), 0);
assertEq(rln.getLeafAtShard(0, idCommitment), 0);
assertEq(rln.userMessageLimits(idCommitment), 1);
vm.expectRevert(DuplicateIdCommitment.selector);
rln.register{ value: MEMBERSHIP_DEPOSIT }(idCommitment, 1);
@ -88,16 +88,16 @@ contract RlnTest is Test {
rln.register{ value: badDepositAmount }(idCommitment, 1);
}
function test__InvalidRegistration__FullSet() public {
Rln tempRln = new Rln(MEMBERSHIP_DEPOSIT, 2, MAX_MESSAGE_LIMIT, address(rln.verifier()));
uint256 setSize = tempRln.SET_SIZE();
for (uint256 i = 1; i <= setSize; i++) {
tempRln.register{ value: MEMBERSHIP_DEPOSIT }(i, 1);
}
assertEq(tempRln.idCommitmentIndex(), 4);
vm.expectRevert(FullTree.selector);
tempRln.register{ value: MEMBERSHIP_DEPOSIT }(setSize + 1, 1);
}
// function test__InvalidRegistration__FullSet() public {
// Rln tempRln = new Rln(MEMBERSHIP_DEPOSIT, 2, MAX_MESSAGE_LIMIT, address(rln.verifier()));
// uint256 setSize = tempRln.SET_SIZE();
// for (uint256 i = 1; i <= setSize; i++) {
// tempRln.register{ value: MEMBERSHIP_DEPOSIT }(i, 1);
// }
// assertEq(tempRln.idCommitmentIndex(), 4);
// vm.expectRevert(FullTree.selector);
// tempRln.register{ value: MEMBERSHIP_DEPOSIT }(setSize + 1, 1);
// }
function test__ValidSlash(uint256 idCommitment, address payable to) public {
// avoid precompiles, etc
@ -111,12 +111,12 @@ contract RlnTest is Test {
assertEq(rln.stakedAmounts(idCommitment), MEMBERSHIP_DEPOSIT);
uint256 balanceBefore = to.balance;
rln.slash(idCommitment, to, zeroedProof);
rln.slash(0, idCommitment, to, zeroedProof);
assertEq(rln.withdrawalBalance(to), MEMBERSHIP_DEPOSIT);
vm.prank(to);
rln.withdraw();
assertEq(rln.stakedAmounts(idCommitment), 0);
assertEq(rln.members(idCommitment), 0);
assertEq(rln.getLeafAtShard(0, idCommitment), 0);
assertEq(rln.withdrawalBalance(to), 0);
assertEq(to.balance, balanceBefore + MEMBERSHIP_DEPOSIT);
}
@ -128,7 +128,7 @@ contract RlnTest is Test {
rln.register{ value: MEMBERSHIP_DEPOSIT }(idCommitment, 1);
assertEq(rln.stakedAmounts(idCommitment), MEMBERSHIP_DEPOSIT);
vm.expectRevert(abi.encodeWithSelector(InvalidReceiverAddress.selector, address(0)));
rln.slash(idCommitment, payable(address(0)), zeroedProof);
rln.slash(0, idCommitment, payable(address(0)), zeroedProof);
}
function test__InvalidSlash__ToRlnAddress() public {
@ -137,13 +137,13 @@ contract RlnTest is Test {
rln.register{ value: MEMBERSHIP_DEPOSIT }(idCommitment, 1);
assertEq(rln.stakedAmounts(idCommitment), MEMBERSHIP_DEPOSIT);
vm.expectRevert(abi.encodeWithSelector(InvalidReceiverAddress.selector, address(rln)));
rln.slash(idCommitment, payable(address(rln)), zeroedProof);
rln.slash(0, idCommitment, payable(address(rln)), zeroedProof);
}
function test__InvalidSlash__MemberNotRegistered(uint256 idCommitment) public {
vm.assume(rln.isValidCommitment(idCommitment));
vm.expectRevert(abi.encodeWithSelector(MemberNotRegistered.selector, idCommitment));
rln.slash(idCommitment, payable(address(this)), zeroedProof);
rln.slash(0, idCommitment, payable(address(this)), zeroedProof);
}
// this shouldn't be possible, but just in case
@ -157,15 +157,15 @@ contract RlnTest is Test {
rln.register{ value: MEMBERSHIP_DEPOSIT }(idCommitment, 1);
assertEq(rln.stakedAmounts(idCommitment), MEMBERSHIP_DEPOSIT);
rln.slash(idCommitment, to, zeroedProof);
rln.slash(0, idCommitment, to, zeroedProof);
assertEq(rln.stakedAmounts(idCommitment), 0);
assertEq(rln.members(idCommitment), 0);
assertEq(rln.getLeafAtShard(0, idCommitment), 0);
// manually set members[idCommitment] to true using vm
stdstore.target(address(rln)).sig("memberExists(uint256)").with_key(idCommitment).depth(0).checked_write(true);
vm.expectRevert(abi.encodeWithSelector(MemberHasNoStake.selector, idCommitment));
rln.slash(idCommitment, to, zeroedProof);
rln.slash(0, idCommitment, to, zeroedProof);
}
function test__InvalidSlash__InvalidProof() public {
@ -177,7 +177,7 @@ contract RlnTest is Test {
tempRln.register{ value: MEMBERSHIP_DEPOSIT }(idCommitment, 1);
vm.expectRevert(InvalidProof.selector);
tempRln.slash(idCommitment, payable(address(this)), zeroedProof);
tempRln.slash(0, idCommitment, payable(address(this)), zeroedProof);
}
function test__InvalidWithdraw__InsufficientWithdrawalBalance() public {
@ -190,9 +190,9 @@ contract RlnTest is Test {
19_014_214_495_641_488_759_237_505_126_948_346_942_972_912_379_615_652_741_039_992_445_865_937_985_820;
rln.register{ value: MEMBERSHIP_DEPOSIT }(idCommitment, 1);
assertEq(rln.stakedAmounts(idCommitment), MEMBERSHIP_DEPOSIT);
rln.slash(idCommitment, payable(address(this)), zeroedProof);
rln.slash(0, idCommitment, payable(address(this)), zeroedProof);
assertEq(rln.stakedAmounts(idCommitment), 0);
assertEq(rln.members(idCommitment), 0);
assertEq(rln.getLeafAtShard(0, idCommitment), 0);
vm.deal(address(rln), 0);
vm.expectRevert(InsufficientContractBalance.selector);
@ -209,9 +209,9 @@ contract RlnTest is Test {
rln.register{ value: MEMBERSHIP_DEPOSIT }(idCommitment, 1);
assertEq(rln.stakedAmounts(idCommitment), MEMBERSHIP_DEPOSIT);
rln.slash(idCommitment, to, zeroedProof);
rln.slash(0, idCommitment, to, zeroedProof);
assertEq(rln.stakedAmounts(idCommitment), 0);
assertEq(rln.members(idCommitment), 0);
assertEq(rln.getLeafAtShard(0, idCommitment), 0);
assertEq(rln.memberExists(idCommitment), false);
vm.prank(to);
@ -219,7 +219,7 @@ contract RlnTest is Test {
assertEq(rln.withdrawalBalance(to), 0);
}
function test__root() public {
function test__kats__root() public {
uint256[] memory idCommitments = new uint256[](20);
idCommitments[0] =
20_247_267_680_401_005_346_274_578_821_543_189_710_026_653_465_287_274_953_093_311_729_853_323_564_993;
@ -270,8 +270,37 @@ contract RlnTest is Test {
vm.resumeGasMetering();
assertEq(
rln.root(),
11_878_758_533_199_576_052_254_314_452_742_479_731_463_159_441_555_548_457_402_116_093_772_672_905_513
rln.root(0),
1_009_702_141_963_982_084_971_194_921_450_594_432_549_814_353_062_162_436_132_398_913_014_207_593_991
);
}
function test__root() public {
vm.pauseGasMetering();
uint256[] memory commitments = new uint256[](1024);
for (uint256 i = 0; i < 1024; i++) {
commitments[i] = i + 1;
}
for (uint256 i = 0; i < 1024; i++) {
rln.register{ value: MEMBERSHIP_DEPOSIT }(commitments[i], 1);
}
vm.resumeGasMetering();
assertEq(
rln.root(0),
15_191_547_950_571_875_204_938_158_154_881_229_414_576_184_216_597_114_679_496_514_970_950_611_126_017
);
}
function test__full_root() public {
uint256[] memory roots = new uint256[](1024);
assertEq(
// repeat 0 1024 times
rln.full_root(roots),
12_413_880_268_183_407_374_852_357_075_976_609_371_175_688_755_676_981_206_018_884_971_008_854_919_922
);
}
}