ens-usernames/contracts/test/MerkleTree.sol

101 lines
3.2 KiB
Solidity

// SPDX-License-Identifier: MIT
pragma solidity >=0.8.25 <0.9.0;
/// @title MerkleTree
/// @notice Test helper that builds a Merkle tree over an in-memory list of leaves and
///         produces / checks inclusion proofs for it.
/// @dev Sibling nodes are hashed in sorted order (`keccak256(min || max)`), so a proof only
///      needs the sibling values and no left/right position bits. `bytes32(0)` is treated as
///      "no node": combining it with any value returns that value unchanged, which is how an
///      odd trailing node is carried up to the next layer. Leaves equal to `bytes32(0)` are
///      therefore indistinguishable from missing leaves.
library MerkleTree {
    /// @notice Thrown by `getProof` when the requested leaf is not present in `elements`.
    error ElementDoesNotExist();

    /// @notice Hashes two sibling nodes into their parent node.
    /// @dev If either input is zero, the other is returned unchanged (an absent sibling).
    ///      Otherwise the pair is sorted before hashing, so the result does not depend on
    ///      argument order.
    /// @param first One child node.
    /// @param second The other child node.
    /// @return The parent node.
    function combinedHash(bytes32 first, bytes32 second) internal pure returns (bytes32) {
        if (first == bytes32(0)) {
            return second;
        }
        if (second == bytes32(0)) {
            return first;
        }
        // Sorting makes the hash commutative. `verifyProof` can then fold siblings in without
        // knowing which side each one was on.
        return first < second
            ? keccak256(abi.encodePacked(first, second))
            : keccak256(abi.encodePacked(second, first));
    }

    /// @notice Builds every layer of the tree, from the leaves (index 0) up to the root layer.
    /// @dev `layers[0]` is the caller's `elements` array itself (same memory reference), not a
    ///      copy. For empty input a single empty layer is returned.
    /// @param elements The leaves, in tree order.
    /// @return layers All layers; the last one holds the root as its only element (when
    ///         `elements` is non-empty).
    function getLayers(bytes32[] memory elements) internal pure returns (bytes32[][] memory) {
        if (elements.length == 0) {
            return new bytes32[][](1);
        }
        // Each layer has ceil(width / 2) nodes; count layers until a single node remains.
        uint256 layerCount = 1;
        uint256 remainingElements = elements.length;
        while (remainingElements > 1) {
            layerCount++;
            remainingElements = (remainingElements + 1) / 2;
        }
        bytes32[][] memory layers = new bytes32[][](layerCount);
        layers[0] = elements;
        for (uint256 i = 1; i < layerCount; i++) {
            layers[i] = getNextLayer(layers[i - 1]);
        }
        return layers;
    }

    /// @notice Computes the parent layer of `elements`.
    /// @dev Adjacent pairs are combined with `combinedHash`. An unpaired last node is combined
    ///      with `bytes32(0)` and so is carried up unchanged.
    /// @param elements A tree layer.
    /// @return The next layer, of length ceil(elements.length / 2).
    function getNextLayer(bytes32[] memory elements) internal pure returns (bytes32[] memory) {
        uint256 nextLayerLength = (elements.length + 1) / 2;
        bytes32[] memory nextLayer = new bytes32[](nextLayerLength);
        for (uint256 i = 0; i < elements.length; i += 2) {
            bytes32 first = elements[i];
            bytes32 second = (i + 1 < elements.length) ? elements[i + 1] : bytes32(0);
            nextLayer[i / 2] = combinedHash(first, second);
        }
        return nextLayer;
    }

    /// @notice Returns the Merkle root of `elements`.
    /// @param elements The leaves, in tree order.
    /// @return The root, or `bytes32(0)` when `elements` is empty.
    function getRoot(bytes32[] memory elements) internal pure returns (bytes32) {
        // Without this guard the empty case would index into an empty leaf layer and panic.
        if (elements.length == 0) {
            return bytes32(0);
        }
        bytes32[][] memory layers = getLayers(elements);
        return layers[layers.length - 1][0];
    }

    /// @notice Builds an inclusion proof for `element`.
    /// @dev The proof lists the sibling at each layer from the leaves upward. Layers where the
    ///      node has no sibling (or the sibling is `bytes32(0)`) contribute nothing: the tree
    ///      carries such nodes up unchanged, so the proof must not include a padding entry.
    ///      Sorted-pair verifiers would otherwise hash the padding and reject the proof.
    /// @param elements The leaves, in tree order.
    /// @param element The leaf to prove. The first occurrence is used if it is duplicated.
    /// @return proof The sibling hashes, ordered from the leaf layer toward the root.
    /// @custom:error ElementDoesNotExist if `element` is not in `elements`.
    function getProof(bytes32[] memory elements, bytes32 element) internal pure returns (bytes32[] memory) {
        uint256 index = indexOf(elements, element);
        if (index == type(uint256).max) {
            revert ElementDoesNotExist();
        }
        bytes32[][] memory layers = getLayers(elements);
        // At most one sibling per layer below the root layer.
        uint256 maxProofLength = layers.length - 1;
        bytes32[] memory proof = new bytes32[](maxProofLength);
        uint256 count = 0;
        for (uint256 i = 0; i < maxProofLength; i++) {
            bytes32[] memory layer = layers[i];
            uint256 pairIndex = (index % 2 == 0) ? index + 1 : index - 1;
            if (pairIndex < layer.length && layer[pairIndex] != bytes32(0)) {
                proof[count] = layer[pairIndex];
                count++;
            }
            index = index / 2;
        }
        // Shrink the array's length word to the number of siblings actually collected.
        assembly ("memory-safe") {
            mstore(proof, count)
        }
        return proof;
    }

    /// @notice Linear search for `element`.
    /// @param elements The array to search.
    /// @param element The value to look for.
    /// @return The index of the first match, or `type(uint256).max` if there is none.
    function indexOf(bytes32[] memory elements, bytes32 element) internal pure returns (uint256) {
        for (uint256 i = 0; i < elements.length; i++) {
            if (elements[i] == element) {
                return i;
            }
        }
        return type(uint256).max;
    }

    /// @notice Checks that `leaf` is included under `root` using `proof`.
    /// @dev Folds each sibling into the running hash with `combinedHash`. Because pairs are
    ///      hashed in sorted order, the sibling's side does not need to be known.
    /// @param proof Sibling hashes from the leaf layer toward the root.
    /// @param root The expected Merkle root.
    /// @param leaf The leaf being proven.
    /// @return True if the recomputed root equals `root`.
    function verifyProof(bytes32[] memory proof, bytes32 root, bytes32 leaf) internal pure returns (bool) {
        bytes32 computedHash = leaf;
        for (uint256 i = 0; i < proof.length; i++) {
            computedHash = combinedHash(computedHash, proof[i]);
        }
        return computedHash == root;
    }
}