diff --git a/contracts/Marketplace.sol b/contracts/Marketplace.sol index b1f3f11..eff5e57 100644 --- a/contracts/Marketplace.sol +++ b/contracts/Marketplace.sol @@ -28,7 +28,7 @@ contract Marketplace is Collateral, Proofs { // address => RequestId Mappings.Mapping private activeHostRequests; // RequestId => SlotId - Mappings.Mapping private activeRequestSlots; + Mappings.Mapping private activeHostRequestSlots; constructor( @@ -46,8 +46,8 @@ contract Marketplace is Collateral, Proofs { } function myRequests() public view returns (RequestId[] memory) { - Mappings.ValueId[] storage valueIds = - activeClientRequests.getValueIds(Mappings.toKeyId(msg.sender)); + Mappings.ValueId[] memory valueIds = + activeClientRequests.values(Mappings.toKeyId(msg.sender)); return _toRequestIds(valueIds); } @@ -57,21 +57,15 @@ contract Marketplace is Collateral, Proofs { returns (SlotId[] memory) { uint256 counter = 0; - uint256 totalSlots = activeRequestSlots.getValueCount(); // set this bigger than our possible filtered list size - if (totalSlots == 0) { - return new SlotId[](0); - } + uint256 totalSlots = activeHostRequestSlots.count(); // set this bigger than our possible filtered list size bytes32[] memory result = new bytes32[](totalSlots); - Mappings.ValueId[] storage valueIds = - activeHostRequests.getValueIds(Mappings.toKeyId(msg.sender)); + Mappings.ValueId[] memory valueIds = + activeHostRequests.values(Mappings.toKeyId(msg.sender)); for (uint256 i = 0; i < valueIds.length; i++) { - // There may exist slots that are still "active", but are part of a request - // that is expired but has not been set to the cancelled state yet. In that - // case, return an empty array. Mappings.KeyId keyId = Mappings.toKeyId(valueIds[i]); - if (activeRequestSlots.keyExists(keyId)) { - Mappings.ValueId[] storage slotIds = - activeRequestSlots.getValueIds(keyId); + if (activeHostRequestSlots.exists(keyId)) { + Mappings.ValueId[] memory slotIds = + activeHostRequestSlots.values(keyId); for (uint256 j = 0; j < slotIds.length; j++) { result[counter] = Mappings.ValueId.unwrap(slotIds[j]); counter++; @@ -100,12 +94,12 @@ contract Marketplace is Collateral, Proofs { context.endsAt = block.timestamp + request.ask.duration; _setProofEnd(_toEndId(id), context.endsAt); - Mappings.KeyId addrBytes32 = Mappings.toKeyId(request.client); - activeClientRequests.insert(addrBytes32, _toValueId(id)); + Mappings.KeyId clientKey = Mappings.toKeyId(request.client); + activeClientRequests.insert(clientKey, _toValueId(id)); - Mappings.KeyId keyId = _toKeyId(id); - if (!activeRequestSlots.keyExists(keyId)) { - activeRequestSlots.insertKey(keyId); + Mappings.KeyId requestKey = _toKeyId(id); + if (!activeHostRequestSlots.exists(requestKey)) { + activeHostRequestSlots.insertKey(requestKey); } _createLock(_toLockId(id), request.expiry); @@ -144,11 +138,8 @@ contract Marketplace is Collateral, Proofs { context.slotsFilled += 1; Mappings.KeyId sender = Mappings.toKeyId(msg.sender); - // address => RequestId activeHostRequests.insert(sender, _toValueId(requestId)); - - // RequestId => SlotId - activeRequestSlots.insert(_toKeyId(requestId), _toValueId(slotId)); + activeHostRequestSlots.insert(_toKeyId(requestId), _toValueId(slotId)); emit SlotFilled(requestId, slotIndex, slotId); if (context.slotsFilled == request.ask.slots) { @@ -159,6 +150,32 @@ contract Marketplace is Collateral, Proofs { } } + function _removeHostSlot(address host, RequestId requestId, SlotId slotId) internal { + Mappings.KeyId requestKey = _toKeyId(requestId); + activeHostRequestSlots.deleteValue(requestKey, _toValueId(slotId)); + + if (activeHostRequestSlots.count(requestKey) == 0) { + Mappings.KeyId hostKey = Mappings.toKeyId(host); + Mappings.ValueId requestValue = _toValueId(requestId); + activeHostRequestSlots.deleteKey(requestKey); + activeHostRequests.deleteValue(hostKey, requestValue); + } + } + + function _removeAllHostSlots(address host, RequestId requestId) internal { + Mappings.KeyId hostKey = Mappings.toKeyId(host); + activeHostRequestSlots.clear(_toKeyId(requestId)); + activeHostRequests.deleteValue(hostKey, _toValueId(requestId)); + } + + function _removeClientRequest(address client, RequestId requestId) internal { + Mappings.ValueId requestValue = _toValueId(requestId); + Mappings.KeyId clientKey = Mappings.toKeyId(client); + if (activeClientRequests.exists(clientKey, requestValue)) { + activeClientRequests.deleteValue(clientKey, requestValue); + } + } + function _freeSlot(SlotId slotId) internal slotMustAcceptProofs(slotId) @@ -175,11 +192,9 @@ contract Marketplace is Collateral, Proofs { // not finalised. _unexpectProofs(_toProofId(slotId)); + _removeHostSlot(slot.host, requestId, slotId); - Mappings.ValueId valueId = _toValueId(slotId); - if (activeRequestSlots.valueExists(valueId)) { - activeRequestSlots.deleteValue(valueId); - } + address slotHost = slot.host; slot.host = address(0); slot.requestId = RequestId.wrap(0); context.slotsFilled -= 1; @@ -194,8 +209,8 @@ contract Marketplace is Collateral, Proofs { context.state = RequestState.Failed; _setProofEnd(_toEndId(requestId), block.timestamp - 1); context.endsAt = block.timestamp - 1; - activeClientRequests.deleteValue(_toValueId(requestId)); - activeRequestSlots.clearValues(_toKeyId(requestId)); + _removeAllHostSlots(slotHost, requestId); + _removeClientRequest(request.client, requestId); emit RequestFailed(requestId); // TODO: burn all remaining slot collateral (note: slot collateral not @@ -210,20 +225,14 @@ contract Marketplace is Collateral, Proofs { { require(_isFinished(requestId), "Contract not ended"); RequestContext storage context = _context(requestId); - // Request storage request = _request(requestId); - context.state = RequestState.Finished; - Mappings.ValueId valueId = _toValueId(requestId); - if (activeClientRequests.valueExists(valueId)) { - activeClientRequests.deleteValue(valueId); - } + Request storage request = _request(requestId); SlotId slotId = _toSlotId(requestId, slotIndex); Slot storage slot = _slot(slotId); require(!slot.hostPaid, "Already paid"); - activeRequestSlots.deleteValue(_toValueId(slotId)); - if (activeRequestSlots.getValueCount() == 0) { - activeRequestSlots.deleteKey(_toKeyId(requestId)); - activeHostRequests.deleteValue(valueId); - } + + context.state = RequestState.Finished; + _removeHostSlot(slot.host, requestId, slotId); + _removeClientRequest(request.client, requestId); uint256 amount = pricePerSlot(requests[requestId]); funds.sent += amount; funds.balance -= amount; @@ -244,8 +253,9 @@ contract Marketplace is Collateral, Proofs { // Update request state to Cancelled. Handle in the withdraw transaction // as there needs to be someone to pay for the gas to update the state context.state = RequestState.Cancelled; - activeClientRequests.deleteValue(_toValueId(requestId)); - activeRequestSlots.clearValues(_toKeyId(requestId)); + // TODO: double-check that we don't want to _removeAllHostSlots() here. + // @markspanbroek? + _removeClientRequest(request.client, requestId); // TODO: handle dangling RequestId in activeHostRequests (for address) emit RequestCancelled(requestId); diff --git a/contracts/TestMappings.sol b/contracts/TestMappings.sol index 01261e0..9d60a1f 100644 --- a/contracts/TestMappings.sol +++ b/contracts/TestMappings.sol @@ -11,77 +11,78 @@ contract TestMappings { Mappings.Mapping private _map; - function getTotalValueCount() public view returns (uint256) { - return _map.getValueCount(); + function totalCount() public view returns (uint256) { + return _map.count(); } - function getValueCount(Mappings.KeyId keyId) public view returns (uint256) { - return _map.getValueCount(keyId); + function count(Mappings.KeyId key) public view returns (uint256) { + return _map.count(key); } - function keyExists(Mappings.KeyId keyId) public view returns (bool) { - return _map.keyExists(keyId); + function keyExists(Mappings.KeyId key) public view returns (bool) { + return _map.exists(key); } - function valueExists(Mappings.ValueId valueId) + function valueExists(Mappings.KeyId key, Mappings.ValueId value) public view returns (bool) { - return _map.valueExists(valueId); + return _map.exists(key, value); } - function getKeyIds() public view returns (Mappings.KeyId[] memory) { - return _map.getKeyIds(); + function keys() public view returns (Mappings.KeyId[] memory) { + return _map.keys(); } - function getValueIds(Mappings.KeyId keyId) + function values(Mappings.KeyId key) public view returns (Mappings.ValueId[] memory) { - return _map.getValueIds(keyId); + return _map.values(key); } - function insertKey(Mappings.KeyId keyId) public returns (bool success) { - success = _map.insertKey(keyId); + function insertKey(Mappings.KeyId key) public returns (bool success) { + success = _map.insertKey(key); emit OperationResult(success); } - function insertValue(Mappings.KeyId keyId, Mappings.ValueId valueId) + function insertValue(Mappings.KeyId key, Mappings.ValueId value) public returns (bool success) { - success = _map.insertValue(keyId, valueId); + success = _map.insertValue(key, value); emit OperationResult(success); } - function insert(Mappings.KeyId keyId, Mappings.ValueId valueId) + function insert(Mappings.KeyId key, Mappings.ValueId value) public returns (bool success) { - success = _map.insert(keyId, valueId); + success = _map.insert(key, value); emit OperationResult(success); } - function deleteKey(Mappings.KeyId keyId) public returns (bool success) { - success = _map.deleteKey(keyId); + function deleteKey(Mappings.KeyId key) public returns (bool success) { + success = _map.deleteKey(key); emit OperationResult(success); } - function deleteValue(Mappings.ValueId valueId) + function deleteValue(Mappings.KeyId key, + Mappings.ValueId value) public returns (bool success) { - success = _map.deleteValue(valueId); + success = _map.deleteValue(key, value); emit OperationResult(success); } - function clearValues(Mappings.KeyId keyId) + function clear(Mappings.KeyId key) public returns (bool success) { - success = _map.clearValues(keyId); + success = _map.clear(key); emit OperationResult(success); } } diff --git a/contracts/libs/Debug.sol b/contracts/libs/Debug.sol index 7c2816c..f70d3b7 100644 --- a/contracts/libs/Debug.sol +++ b/contracts/libs/Debug.sol @@ -1,10 +1,15 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.8; +import "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import "./Mappings.sol"; + import "hardhat/console.sol"; // DELETE ME library Debug { + using Mappings for Mappings.Mapping; + using EnumerableSet for EnumerableSet.Bytes32Set; + function _toHex16 (bytes16 data) private pure returns (bytes32 result) { result = bytes32 (data) & 0xFFFFFFFFFFFFFFFF000000000000000000000000000000000000000000000000 | (bytes32 (data) & 0x0000000000000000FFFFFFFFFFFFFFFF00000000000000000000000000000000) >> 64; @@ -50,11 +55,11 @@ library Debug { console.log("| Key | Value |"); console.log("| ------------------------------------------------------------------ | ------------------------------------------------------------------ |"); uint256 referencedValues = 0; - for(uint8 i = 0; i < db._keyIds.length; i++) { - Mappings.KeyId keyId = db._keyIds[i]; - console.log("|", _toHex(Mappings.KeyId.unwrap(keyId)), "| |"); + for(uint8 i = 0; i < db._keyIds.length(); i++) { + bytes32 keyId = db._keyIds.at(i); + console.log("|", _toHex(keyId), "| |"); - Mappings.ValueId[] storage valueIds = Mappings.getValueIds(db, keyId); + Mappings.ValueId[] memory valueIds = db.values(Mappings.KeyId.wrap(keyId)); for(uint8 j = 0; j < valueIds.length; j++) { Mappings.ValueId valueId = valueIds[j]; console.log("| |", _toHex(Mappings.ValueId.unwrap(valueId)), "|"); @@ -63,7 +68,7 @@ library Debug { } console.log("|_________________________________________________________________________________________________________________________________________|"); console.log(" Referenced values: ", referencedValues); - uint256 totalValues = Mappings.getValueCount(db); + uint256 totalValues = db.count(); console.log(" Unreferenced values: ", totalValues - referencedValues, " (total values not deleted but are unused)"); console.log(" TOTAL Values: ", totalValues); } diff --git a/contracts/libs/EnumerableSetExtensions.sol b/contracts/libs/EnumerableSetExtensions.sol new file mode 100644 index 0000000..7c0af23 --- /dev/null +++ b/contracts/libs/EnumerableSetExtensions.sol @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.8; + +import "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; + +library EnumerableSetExtensions { + using EnumerableSet for EnumerableSet.Bytes32Set; + using EnumerableSet for EnumerableSet.AddressSet; + + struct ClearableBytes32Set { + mapping(uint256 => + EnumerableSet.Bytes32Set) _values; + uint256 _index; + } + + /// @notice Returns the EnumerableSet.Bytes32 containing the values + /// @dev This is used internally to the library only. `.values()` should only + /// be called on its return value in a view/pure function. + /// @param map ClearableBytes32Set to list values + /// @return EnumerableSet.Bytes32 containing values + function _set(ClearableBytes32Set storage map) + private + view + returns (EnumerableSet.Bytes32Set storage) + { + return map._values[map._index]; + } + + /// @notice Lists all values for a key and address in an ClearableBytes32Set + /// @param map ClearableBytes32Set to list values + /// @return bytes32[] array of bytes32 values + function values(ClearableBytes32Set storage map) + internal + view + returns (bytes32[] memory) + { + return _set(map).values(); + } + + /// @notice Adds a single value to a ClearableBytes32Set + /// @param map Bytes32SetMap to add the value to + /// @param value the value to be added + /// @return true if the value was added to the set, that is if it was not + /// already present. + function add(ClearableBytes32Set storage map, + bytes32 value) + internal + returns (bool) + { + return _set(map).add(value); + } + + /// @notice Removes a single value from a ClearableBytes32Set + /// @param map Bytes32SetMap to remove the value from + /// @param value the value to be removed + /// @return true if the value was removed from the set, that is if it was + /// present. + function remove(ClearableBytes32Set storage map, + bytes32 value) + internal + returns (bool) + { + return _set(map).remove(value); + } + + /// @notice Clears all values. + /// @dev Updates an index such that the next time values for that key are + /// retrieved, it will reference a new EnumerableSet. + /// @param map ClearableBytes32Set for which to clear values + function clear(ClearableBytes32Set storage map) + internal + { + map._index++; + } + + /// @notice Returns the length of values for a key and address. + /// @param map ClearableBytes32Set for which to get length of values + function length(ClearableBytes32Set storage map) + internal + view + returns (uint256) + { + return _set(map).length(); + } + + /// @notice Lists all values for a key in an Bytes32SetMap + /// @param map Bytes32SetMap to list values + /// @return bytes32[] array of bytes32 values + function contains(ClearableBytes32Set storage map, + bytes32 value) + internal + view + returns (bool) + { + return _set(map).contains(value); + } +} diff --git a/contracts/libs/Mappings.sol b/contracts/libs/Mappings.sol index 117f65a..243cb3a 100644 --- a/contracts/libs/Mappings.sol +++ b/contracts/libs/Mappings.sol @@ -2,232 +2,191 @@ // heavily inspired by: https://bitbucket.org/rhitchens2/soliditystoragepatterns/src/master/GeneralizedCollection.sol pragma solidity ^0.8.8; -import "./Debug.sol"; // DELETE ME +import "./EnumerableSetExtensions.sol"; +import "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; library Mappings { + + using EnumerableSet for EnumerableSet.Bytes32Set; + using EnumerableSetExtensions for EnumerableSetExtensions.ClearableBytes32Set; + type KeyId is bytes32; type ValueId is bytes32; - // first entity is called a "One" - struct Key { - // needed to delete a "One" - uint256 _oneListPointer; - // One has many "Many" - ValueId[] _valueIds; - mapping(ValueId => uint256) _valueIdsIndex; // valueId => row of local _valueIds - // more app data - } - - // other entity is called a "Many" - struct Value { - // needed to delete a "Many" - uint256 _valueIdsIndex; - // many has exactly one "One" - KeyId _keyId; - // add app fields - } - struct Mapping { + EnumerableSet.Bytes32Set _keyIds; + EnumerableSet.Bytes32Set _valueIds; mapping(KeyId => Key) _keys; - KeyId[] _keyIds; mapping(ValueId => Value) _values; - ValueId[] _valueIds; + } + struct Key { + string name; + bool delux; + uint price; + + EnumerableSetExtensions.ClearableBytes32Set _values; + } + struct Value { + string name; + bool delux; + uint price; + + KeyId _keyId; } - function keyCount(Mapping storage db) + function exists(Mapping storage map, KeyId key) + internal + view + returns (bool) + { + return map._keyIds.contains(KeyId.unwrap(key)); + } + + function exists(Mapping storage map, KeyId key, ValueId value) + internal + view + returns (bool) + { + bytes32 val = ValueId.unwrap(value); + return map._keys[key]._values.contains(val) && + map._valueIds.contains(val); + } + + function keys(Mapping storage map) + internal + view + returns(KeyId[] memory) + { + return _toKeyIds(map._keyIds.values()); + } + + function values(Mapping storage map, + KeyId key) + internal + view + returns(ValueId[] memory) + { + require(exists(map, key), "key does not exist"); + return _toValueIds(map._keys[key]._values.values()); + } + + function count(Mapping storage map, + KeyId key) internal view returns(uint256) { - return db._keyIds.length; + require(exists(map, key), "key does not exist"); + return map._keys[key]._values.length(); } - function getValueCount(Mapping storage db) internal view returns(uint256) { - return db._valueIds.length; - } - - function getValueCount(Mapping storage db, KeyId keyId) + function count(Mapping storage map) internal view returns(uint256) { - return getValueIds(db, keyId).length; + return map._valueIds.length(); } - function keyExists(Mapping storage db, KeyId keyId) + function insertKey(Mapping storage map, KeyId key) internal - view - returns(bool) + returns (bool) { - if(keyCount(db) == 0) return false; - return equals(db._keyIds[db._keys[keyId]._oneListPointer], keyId); + require(!exists(map, key), "key already exists"); + return map._keyIds.add(KeyId.unwrap(key)); + // NOTE: map._keys[key]._values contains a default EnumerableSet.Bytes32Set } - function valueExists(Mapping storage db, ValueId valueId) + function insertValue(Mapping storage map, KeyId key, ValueId value) internal - view - returns(bool) + returns (bool success) { - if (getValueCount(db) == 0) return false; - uint256 row = db._values[valueId]._valueIdsIndex; - bool retVal = equals(db._valueIds[row], valueId); - return retVal; + require(exists(map, key), "key does not exists"); + require(!exists(map, key, value), "value already exists"); + success = map._valueIds.add(ValueId.unwrap(value)); + assert (success); // value addition failure + map._values[value]._keyId = key; + + success = map._keys[key]._values.add(ValueId.unwrap(value)); } - function getKeyIds(Mapping storage db) + function insert(Mapping storage map, KeyId key, ValueId value) internal - view - returns(KeyId[] storage) + returns (bool success) { - return db._keyIds; - } - - function getValueIds(Mapping storage db, - KeyId keyId) - internal - view - returns(ValueId[] storage) - { - require(keyExists(db, keyId), "key does not exist"); - return db._keys[keyId]._valueIds; - } - - // Insert - function insertKey(Mapping storage db, KeyId keyId) - internal - returns(bool) - { - require(!keyExists(db, keyId), "key already exists"); // duplicate key prohibited - - db._keyIds.push(keyId); - db._keys[keyId]._oneListPointer = keyCount(db) - 1; - return true; - } - - function insertValue(Mapping storage db, KeyId keyId, ValueId valueId) - internal - returns(bool) - { - require(keyExists(db, keyId), "key does not exist"); - require(!valueExists(db, valueId), "value already exists"); // duplicate key prohibited - - Value storage value = db._values[valueId]; - db._valueIds.push(valueId); - value._valueIdsIndex = getValueCount(db) - 1; - value._keyId = keyId; // each many has exactly one "One", so this is mandatory - - // We also maintain a list of "Many" that refer to the "One", so ... - Key storage key = db._keys[keyId]; - key._valueIds.push(valueId); - key._valueIdsIndex[valueId] = key._valueIds.length - 1; - return true; - } - - function insert(Mapping storage db, KeyId keyId, ValueId valueId) - internal - returns(bool success) - { - if (!keyExists(db, keyId)) { - success = insertKey(db, keyId); - if (!success) { - return false; - } + if (!exists(map, key)) { + success = insertKey(map, key); + assert (success); // key insertion failure } - if (!valueExists(db, valueId)) { - success = insertValue(db, keyId, valueId); - } - return success; + map._valueIds.add(ValueId.unwrap(value)); + map._values[value]._keyId = key; + + success = map._keys[key]._values.add(ValueId.unwrap(value)); } - - // Delete - function deleteKey(Mapping storage db, KeyId keyId) + function deleteKey(Mapping storage map, KeyId key) internal - returns(bool) + returns (bool success) { - require(keyExists(db, keyId), "key does not exist"); - require(getValueIds(db, keyId).length == 0, "references values"); // this would break referential integrity - - uint256 rowToDelete = db._keys[keyId]._oneListPointer; - KeyId keyToMove = db._keyIds[keyCount(db)-1]; - db._keyIds[rowToDelete] = keyToMove; - db._keys[keyToMove]._oneListPointer = rowToDelete; - db._keyIds.pop(); - delete db._keys[keyId]; - return true; + require(exists(map, key), "key does not exist"); + require(count(map, key) == 0, "references values"); + success = map._keyIds.remove(KeyId.unwrap(key)); // Note that this will fail automatically if the key doesn't exist + assert(success); // key removal failure + delete map._keys[key]; } - function deleteValue(Mapping storage db, ValueId valueId) + function deleteValue(Mapping storage map, KeyId key, ValueId value) internal - returns(bool) + returns (bool success) { - require(valueExists(db, valueId), "value does not exist"); // non-existant key + require(exists(map, key), "key does not exist"); + require(exists(map, key, value), "value does not exist"); - // delete from the Many table - uint256 toDeleteIndex = db._values[valueId]._valueIdsIndex; + success = map._valueIds.remove(ValueId.unwrap(value)); + assert (success); // value removal failure + delete map._values[value]; - uint256 lastIndex = getValueCount(db) - 1; + success = map._keys[key]._values.remove(ValueId.unwrap(value)); - if (lastIndex != toDeleteIndex) { - ValueId lastValue = db._valueIds[lastIndex]; - - // Move the last value to the index where the value to delete is - db._valueIds[toDeleteIndex] = lastValue; - // Update the index for the moved value - db._values[lastValue]._valueIdsIndex = toDeleteIndex; // Replace lastvalue's index to valueIndex - } - db._valueIds.pop(); - - KeyId keyId = db._values[valueId]._keyId; - Key storage oneRow = db._keys[keyId]; - toDeleteIndex = oneRow._valueIdsIndex[valueId]; - lastIndex = oneRow._valueIds.length - 1; - if (lastIndex != toDeleteIndex) { - ValueId lastValue = oneRow._valueIds[lastIndex]; - - // Move the last value to the index where the value to delete is - oneRow._valueIds[toDeleteIndex] = lastValue; - // Update the index for the moved value - oneRow._valueIdsIndex[lastValue] = toDeleteIndex; // Replace lastvalue's index to valueIndex - } - oneRow._valueIds.pop(); - delete oneRow._valueIdsIndex[valueId]; - delete db._values[valueId]; - - if (getValueCount(db, keyId) == 0) { - deleteKey(db, keyId); - } - return true; } - function clearValues(Mapping storage db, KeyId keyId) + function clear(Mapping storage map, KeyId key) internal - returns(bool) + returns (bool success) { - require(keyExists(db, keyId), "key does not exist"); // non-existant key + require(exists(map, key), "key does not exist"); - Debug._printTable(db, "[clearValues] BEFORE clearing"); - // delete db._valueIds; - delete db._keys[keyId]._valueIds; - bool result = deleteKey(db, keyId); - Debug._printTable(db, "[clearValues] AFTER clearing"); - return result; + map._keys[key]._values.clear(); + success = deleteKey(map, key); } - function equals(KeyId a, KeyId b) internal pure returns (bool) { - return KeyId.unwrap(a) == KeyId.unwrap(b); + function _toKeyIds(bytes32[] memory array) + private + pure + returns (KeyId[] memory result) + { + // solhint-disable-next-line no-inline-assembly + assembly { + result := array + } } - function equals(ValueId a, ValueId b) internal pure returns (bool) { - return ValueId.unwrap(a) == ValueId.unwrap(b); + function _toValueIds(bytes32[] memory array) + private + pure + returns (ValueId[] memory result) + { + // solhint-disable-next-line no-inline-assembly + assembly { + result := array + } + } + + function toKeyId(ValueId valueId) internal pure returns (KeyId) { + return KeyId.wrap(ValueId.unwrap(valueId)); } function toKeyId(address addr) internal pure returns (KeyId) { return KeyId.wrap(bytes32(uint(uint160(addr)))); } - - // Useful in the case where a valueId is a foreign key - function toKeyId(ValueId valueId) internal pure returns (KeyId) { - return KeyId.wrap(ValueId.unwrap(valueId)); - } } diff --git a/contracts/libs/TestEnumerableSetExtensions.sol b/contracts/libs/TestEnumerableSetExtensions.sol new file mode 100644 index 0000000..520a694 --- /dev/null +++ b/contracts/libs/TestEnumerableSetExtensions.sol @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import "./EnumerableSetExtensions.sol"; + +// exposes public functions for testing +contract TestClearableBytes32Set { + using EnumerableSet for EnumerableSet.Bytes32Set; + using EnumerableSetExtensions for EnumerableSetExtensions.ClearableBytes32Set; + + event OperationResult(bool result); + + EnumerableSetExtensions.ClearableBytes32Set private _set; + + function values() + public + view + returns (bytes32[] memory) + { + return _set.values(); + } + + function add(bytes32 value) + public + { + bool result = _set.add(value); + emit OperationResult(result); + } + + function remove(bytes32 value) + public + { + bool result = _set.remove(value); + emit OperationResult(result); + } + + function clear() + public + { + _set.clear(); + } + + function length() + public + view + returns (uint256) + { + return _set.length(); + } + + function contains(bytes32 value) + public + view + returns (bool) + { + return _set.contains(value); + } +} diff --git a/test/EnumerableSetExtensions.test.js b/test/EnumerableSetExtensions.test.js new file mode 100644 index 0000000..9291594 --- /dev/null +++ b/test/EnumerableSetExtensions.test.js @@ -0,0 +1,90 @@ +const { ethers } = require("hardhat") +const { expect } = require("chai") +const { hexlify, randomBytes } = ethers.utils +const { exampleAddress } = require("./examples") + +describe("EnumerableSetExtensions", function () { + let account + let key + let value + let contract + + describe("ClearableBytes32Set", function () { + beforeEach(async function () { + let ClearableBytes32Set = await ethers.getContractFactory( + "TestClearableBytes32Set" + ) + contract = await ClearableBytes32Set.deploy() + ;[account] = await ethers.getSigners() + value = randomBytes(32) + }) + + it("starts empty", async function () { + await expect(await contract.values()).to.deep.equal([]) + }) + + it("adds a value", async function () { + await expect(contract.add(value)) + .to.emit(contract, "OperationResult") + .withArgs(true) + await expect(await contract.values()).to.deep.equal([hexlify(value)]) + }) + + it("adds a value that already exists", async function () { + await contract.add(value) + await expect(contract.add(value)) + .to.emit(contract, "OperationResult") + .withArgs(false) + await expect(await contract.values()).to.deep.equal([hexlify(value)]) + }) + + it("contains a value", async function () { + let key1 = randomBytes(32) + let value1 = randomBytes(32) + await contract.add(value) + await contract.add(value1) + await expect(await contract.contains(value)).to.equal(true) + await expect(await contract.contains(value1)).to.equal(true) + }) + + it("removes a value", async function () { + let value1 = randomBytes(32) + await contract.add(value) + await contract.add(value1) + await expect(contract.remove(value)) + .to.emit(contract, "OperationResult") + .withArgs(true) + await expect(await contract.values()).to.deep.equal([hexlify(value1)]) + }) + + it("removes a value that doesn't exist", async function () { + let value1 = randomBytes(32) + await contract.add(value) + await contract.add(value1) + await contract.remove(value) + await expect(contract.remove(value)) + .to.emit(contract, "OperationResult") + .withArgs(false) + await expect(await contract.values()).to.deep.equal([hexlify(value1)]) + }) + + it("clears all values", async function () { + let value1 = randomBytes(32) + let value2 = randomBytes(32) + await contract.add(value) + await contract.add(value1) + await contract.add(value2) + await expect(contract.clear()) + await expect(await contract.values()).to.deep.equal([]) + }) + + it("gets the length of values", async function () { + let value1 = randomBytes(32) + let value2 = randomBytes(32) + await contract.add(value) + await contract.add(value1) + await contract.add(value2) + await expect(await contract.length()).to.equal(3) + }) + }) +}) diff --git a/test/Mappings.test.js b/test/Mappings.test.js index dacfbb9..e41507f 100644 --- a/test/Mappings.test.js +++ b/test/Mappings.test.js @@ -20,18 +20,16 @@ describe("Mappings", function () { it("starts empty", async function () { await expect(await contract.keyExists(key)).to.be.false - await expect(await contract.valueExists(value)).to.be.false - await expect(await contract.getKeyIds()).to.deep.equal([]) - await expect(await contract.getTotalValueCount()).to.equal(0) + await expect(await contract.valueExists(key, value)).to.be.false + await expect(await contract.keys()).to.deep.equal([]) + await expect(await contract.totalCount()).to.equal(0) }) it("adds a key and value", async function () { await expect(contract.insert(key, value)) .to.emit(contract, "OperationResult") .withArgs(true) - await expect(await contract.getValueIds(key)).to.deep.equal([ - hexlify(value), - ]) + await expect(await contract.values(key)).to.deep.equal([hexlify(value)]) }) it("removes a key", async function () { @@ -46,13 +44,11 @@ describe("Mappings", function () { let value1 = randomBytes(32) await contract.insert(key, value) await contract.insert(key, value1) - await expect(contract.deleteValue(value)) + await expect(contract.deleteValue(key, value)) .to.emit(contract, "OperationResult") .withArgs(true) - await expect(await contract.getKeyIds()).to.deep.equal([hexlify(key)]) - await expect(await contract.getValueIds(key)).to.deep.equal([ - hexlify(value1), - ]) + await expect(await contract.keys()).to.deep.equal([hexlify(key)]) + await expect(await contract.values(key)).to.deep.equal([hexlify(value1)]) }) // referential integrity @@ -64,7 +60,7 @@ describe("Mappings", function () { }) it("fails to get value ids when key does not exist", async function () { - await expect(contract.getValueIds(key)).to.be.revertedWith( + await expect(contract.values(key)).to.be.revertedWith( "key does not exist" ) }) @@ -95,35 +91,35 @@ describe("Mappings", function () { let value3 = randomBytes(32) await contract.insert(key, value) await expect(await contract.keyExists(key)).to.be.true - await expect(await contract.valueExists(value)).to.be.true - await expect(await contract.valueExists(value1)).to.be.false - await expect(await contract.getValueCount(key)).to.equal(1) - await expect(await contract.getTotalValueCount()).to.equal(1) + await expect(await contract.valueExists(key, value)).to.be.true + await expect(await contract.valueExists(key, value1)).to.be.false + await expect(await contract.count(key)).to.equal(1) + await expect(await contract.totalCount()).to.equal(1) await contract.insert(key, value1) - await expect(await contract.valueExists(value1)).to.be.true - await expect(await contract.getValueCount(key)).to.equal(2) - await expect(await contract.getTotalValueCount()).to.equal(2) + await expect(await contract.valueExists(key, value1)).to.be.true + await expect(await contract.count(key)).to.equal(2) + await expect(await contract.totalCount()).to.equal(2) - await expect(contract.deleteValue(value1)) - await expect(await contract.keyExists(key)).to.be.true - await expect(await contract.valueExists(value1)).to.be.false - await expect(await contract.getValueCount(key)).to.equal(1) - await expect(await contract.getTotalValueCount()).to.equal(1) + await expect(contract.deleteValue(key, value1)) + await expect(await contract.valueExists(key, value)).to.be.true + await expect(await contract.valueExists(key, value1)).to.be.false + await expect(await contract.count(key)).to.equal(1) + await expect(await contract.totalCount()).to.equal(1) await contract.insert(key, value1) await contract.insert(key, value2) await contract.insert(key, value3) - await expect(contract.clearValues(key)) + await expect(contract.clear(key)) await expect(await contract.keyExists(key)).to.be.false - await expect(await contract.getKeyIds()).to.deep.equal([]) + await expect(await contract.keys()).to.deep.equal([]) // TODO: handle unreferenced values, as visible here. Once handled, this value should be 1 - await expect(await contract.getTotalValueCount()).to.equal(4) - // await expect(await contract.valueExists(value)).to.be.false - // await expect(await contract.valueExists(value1)).to.be.false - // await expect(await contract.valueExists(value2)).to.be.false - // await expect(await contract.valueExists(value3)).to.be.false + // await expect(await contract.totalCount()).to.equal(4) + // await expect(await contract.valueExists(key, value)).to.be.false + // await expect(await contract.valueExists(key, value1)).to.be.false + // await expect(await contract.valueExists(key, value2)).to.be.false + // await expect(await contract.valueExists(key, value3)).to.be.false }) }) })