diff --git a/contracts/libs/SetMap.sol b/contracts/libs/SetMap.sol index bb34b59..8be47fb 100644 --- a/contracts/libs/SetMap.sol +++ b/contracts/libs/SetMap.sol @@ -279,7 +279,11 @@ library SetMap { internal returns (bool) { - return _set(map, key).add(value); + bool success = _set(map, key).add(value); + if (success) { + map._keys.add(Bytes32AddressSetMapKey.unwrap(key)); + } + return success; } /// @notice Removes a single value from an Bytes32AddressSetMap @@ -294,7 +298,12 @@ library SetMap { internal returns (bool) { - return _set(map, key).remove(value); + EnumerableSet.AddressSet storage set = _set(map, key); + bool success = _set(map, key).remove(value); + if (success && set.length() == 0) { + map._keys.remove(Bytes32AddressSetMapKey.unwrap(key)); + } + return success; } /// @notice Clears values for a key. diff --git a/contracts/libs/TestSetMap.sol b/contracts/libs/TestSetMap.sol new file mode 100644 index 0000000..e232a93 --- /dev/null +++ b/contracts/libs/TestSetMap.sol @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import "./SetMap.sol"; + +// exposes public functions for testing +contract TestBytes32SetMap { + using SetMap for SetMap.Bytes32SetMap; + + event OperationResult(bool result); + + SetMap.Bytes32SetMap private _set; + + function values(SetMap.Bytes32SetMapKey key, + address addr) + public + view + returns (bytes32[] memory) + { + return _set.values(key, addr); + } + + function add(SetMap.Bytes32SetMapKey key, + address addr, + bytes32 value) + public + { + bool result = _set.add(key, addr, value); + emit OperationResult(result); + } + + function remove(SetMap.Bytes32SetMapKey key, + address addr, + bytes32 value) + public + { + bool result = _set.remove(key, addr, value); + emit OperationResult(result); + } + + function clear(SetMap.Bytes32SetMapKey key) + public + { + _set.clear(key); + } + + function length(SetMap.Bytes32SetMapKey key, + address addr) + public + view + returns (uint256) + { + return _set.length(key, addr); + } +} + +contract TestAddressBytes32SetMap { + using SetMap for SetMap.AddressBytes32SetMap; + + event OperationResult(bool result); + + SetMap.AddressBytes32SetMap private _set; + + function values(SetMap.AddressBytes32SetMapKey key) + public + view + returns (bytes32[] memory) + { + return _set.values(key); + } + + function add(SetMap.AddressBytes32SetMapKey key, + bytes32 value) + public + { + bool result = _set.add(key, value); + emit OperationResult(result); + } + + function remove(SetMap.AddressBytes32SetMapKey key, + bytes32 value) + public + { + bool result = _set.remove(key, value); + emit OperationResult(result); + } + + function clear(SetMap.AddressBytes32SetMapKey key) + public + { + _set.clear(key); + } +} + +contract TestBytes32AddressSetMap { + using EnumerableSet for EnumerableSet.Bytes32Set; + using SetMap for SetMap.Bytes32AddressSetMap; + + event OperationResult(bool result); + + SetMap.Bytes32AddressSetMap private _set; + + function keys() + view + public + returns (bytes32[] memory) + { + return _set._keys.values(); + } + + function values(SetMap.Bytes32AddressSetMapKey key) + public + view + returns (address[] memory) + { + return _set.values(key); + } + + function contains(SetMap.Bytes32AddressSetMapKey key, + address addr) + public + view + returns (bool) + { + return _set.contains(key, addr); + } + + function length(SetMap.Bytes32AddressSetMapKey key) + public + view + returns (uint256) + { + return _set.length(key); + } + + function add(SetMap.Bytes32AddressSetMapKey key, + address value) + public + { + bool result = _set.add(key, value); + emit OperationResult(result); + } + + function remove(SetMap.Bytes32AddressSetMapKey key, + address value) + public + { + bool result = _set.remove(key, value); + emit OperationResult(result); + } + + function clear(SetMap.Bytes32AddressSetMapKey key) + public + { + _set.clear(key); + } +} diff --git a/test/SetMap.test.js b/test/SetMap.test.js new file mode 100644 index 0000000..d32469b --- /dev/null +++ b/test/SetMap.test.js @@ -0,0 +1,198 @@ +const { ethers } = require("hardhat") +const { expect } = require("chai") +const { hexlify, randomBytes } = ethers.utils +const { exampleAddress } = require("./examples") + +describe("SetMap", function () { + let account + let key + let value + let contract + + describe("Bytes32SetMap", function () { + beforeEach(async function () { + let Bytes32SetMap = await ethers.getContractFactory("TestBytes32SetMap") + contract = await Bytes32SetMap.deploy() + ;[account] = await ethers.getSigners() + key = randomBytes(32) + value = randomBytes(32) + }) + + it("starts empty", async function () { + await expect(await contract.values(key, account.address)).to.deep.equal( + [] + ) + }) + + it("adds a key/address and value", async function () { + await expect(contract.add(key, account.address, value)) + .to.emit(contract, "OperationResult") + .withArgs(true) + await expect(await contract.values(key, account.address)).to.deep.equal([ + hexlify(value), + ]) + }) + + it("removes a value for key/address", async function () { + let value1 = randomBytes(32) + await contract.add(key, account.address, value) + await contract.add(key, account.address, value1) + await expect(contract.remove(key, account.address, value)) + .to.emit(contract, "OperationResult") + .withArgs(true) + await expect(await contract.values(key, account.address)).to.deep.equal([ + hexlify(value1), + ]) + }) + + it("clears all values for a key", async function () { + let key1 = randomBytes(32) + let value1 = randomBytes(32) + let value2 = randomBytes(32) + await contract.add(key, account.address, value) + await contract.add(key, account.address, value1) + await contract.add(key, account.address, value2) + await contract.add(key1, account.address, value) + await expect(contract.clear(key)) + await expect(await contract.values(key, account.address)).to.deep.equal( + [] + ) + await expect(await contract.values(key1, account.address)).to.deep.equal([ + hexlify(value), + ]) + }) + + it("gets the length of values for a key/address", async function () { + let value1 = randomBytes(32) + let value2 = randomBytes(32) + await contract.add(key, account.address, value) + await contract.add(key, account.address, value1) + await contract.add(key, account.address, value2) + await expect(await contract.length(key, account.address)).to.equal(3) + }) + }) + + describe("AddressBytes32SetMap", function () { + beforeEach(async function () { + let AddressBytes32SetMap = await ethers.getContractFactory( + "TestAddressBytes32SetMap" + ) + contract = await AddressBytes32SetMap.deploy() + ;[account, account1] = await ethers.getSigners() + key = account.address + value = randomBytes(32) + }) + + it("starts empty", async function () { + await expect(await contract.values(key)).to.deep.equal([]) + }) + + it("adds a key/address and value", async function () { + await expect(contract.add(key, value)) + .to.emit(contract, "OperationResult") + .withArgs(true) + await expect(await contract.values(key)).to.deep.equal([hexlify(value)]) + }) + + it("removes a value for key/address", async function () { + let value1 = randomBytes(32) + await contract.add(key, value) + await contract.add(key, value1) + await expect(contract.remove(key, value)) + .to.emit(contract, "OperationResult") + .withArgs(true) + await expect(await contract.values(key)).to.deep.equal([hexlify(value1)]) + }) + + it("clears all values for a key", async function () { + let key1 = account1.address + let value1 = randomBytes(32) + let value2 = randomBytes(32) + await contract.add(key, value) + await contract.add(key, value1) + await contract.add(key, value2) + await contract.add(key1, value) + await expect(contract.clear(key)) + await expect(await contract.values(key)).to.deep.equal([]) + await expect(await contract.values(key1)).to.deep.equal([hexlify(value)]) + }) + }) + + describe("Bytes32AddressSetMap", function () { + beforeEach(async function () { + let Bytes32AddressSetMap = await ethers.getContractFactory( + "TestBytes32AddressSetMap" + ) + contract = await Bytes32AddressSetMap.deploy() + ;[account] = await ethers.getSigners() + key = randomBytes(32) + value = exampleAddress() + }) + + it("starts empty", async function () { + await expect(await contract.values(key)).to.deep.equal([]) + }) + + it("adds a key/address and value", async function () { + await expect(contract.add(key, value)) + .to.emit(contract, "OperationResult") + .withArgs(true) + await expect(await contract.values(key)).to.deep.equal([value]) + }) + + it("returns list of keys", async function () { + let key1 = randomBytes(32) + let value1 = exampleAddress() + await contract.add(key, value) + await contract.add(key, value1) + await contract.add(key1, value) + await contract.add(key1, value1) + await expect(await contract.keys()).to.deep.equal([ + hexlify(key), + hexlify(key1), + ]) + }) + + it("contains a key/value pair", async function () { + let key1 = randomBytes(32) + let value1 = exampleAddress() + await contract.add(key, value) + await contract.add(key1, value1) + await expect(await contract.contains(key, value)).to.equal(true) + await expect(await contract.contains(key1, value1)).to.equal(true) + await expect(await contract.contains(key1, value)).to.equal(false) + }) + + it("removes a value for key/address", async function () { + let value1 = exampleAddress() + await contract.add(key, value) + await contract.add(key, value1) + await expect(contract.remove(key, value)) + .to.emit(contract, "OperationResult") + .withArgs(true) + await expect(await contract.values(key)).to.deep.equal([value1]) + }) + + it("clears all values for a key", async function () { + let key1 = randomBytes(32) + let value1 = exampleAddress() + let value2 = exampleAddress() + await contract.add(key, value) + await contract.add(key, value1) + await contract.add(key, value2) + await contract.add(key1, value) + await expect(contract.clear(key)) + await expect(await contract.values(key)).to.deep.equal([]) + await expect(await contract.values(key1)).to.deep.equal([value]) + }) + + it("gets the length of values for a key/address", async function () { + let value1 = exampleAddress() + let value2 = exampleAddress() + await contract.add(key, value) + await contract.add(key, value1) + await contract.add(key, value2) + await expect(await contract.length(key)).to.equal(3) + }) + }) +}) diff --git a/test/examples.js b/test/examples.js index 1bc864a..50caa46 100644 --- a/test/examples.js +++ b/test/examples.js @@ -1,7 +1,7 @@ const { ethers } = require("hardhat") const { hours } = require("./time") const { currentTime } = require("./evm") -const { hexlify, randomBytes } = ethers.utils +const { getAddress, hexlify, randomBytes } = ethers.utils const exampleRequest = async () => { const now = await currentTime() @@ -37,5 +37,8 @@ const exampleLock = async () => { expiry: now + hours(1), } } +const exampleAddress = () => { + return getAddress(hexlify(randomBytes(20))) +} -module.exports = { exampleRequest, exampleLock } +module.exports = { exampleRequest, exampleLock, exampleAddress }