diff --git a/contracts/Marketplace.sol b/contracts/Marketplace.sol index b788aff..d13b4fc 100644 --- a/contracts/Marketplace.sol +++ b/contracts/Marketplace.sol @@ -7,6 +7,8 @@ import "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import "./Collateral.sol"; import "./Proofs.sol"; +import "hardhat/console.sol"; + contract Marketplace is Collateral, Proofs { using EnumerableSet for EnumerableSet.Bytes32Set; @@ -19,7 +21,7 @@ contract Marketplace is Collateral, Proofs { mapping(RequestId => RequestContext) private requestContexts; mapping(SlotId => Slot) private slots; mapping(address => EnumerableSet.Bytes32Set) private activeRequests; - mapping(address => mapping(uint8 => EnumerableSet.Bytes32Set)) private activeSlots; + mapping(address => EnumerableSet.Bytes32Set) private activeSlots; mapping(address => uint8) private activeSlotsIdx; constructor( @@ -41,25 +43,33 @@ contract Marketplace is Collateral, Proofs { } function mySlots() public view returns (SlotId[] memory) { - EnumerableSet.Bytes32Set storage slotIds = _activeSlotsForHost(msg.sender); - return _toSlotIds(slotIds.values()); + return _activeSlotsForHost(msg.sender); } function _activeSlotsForHost(address host) internal view - returns (EnumerableSet.Bytes32Set storage) + returns (SlotId[] memory) { - mapping(uint8 => EnumerableSet.Bytes32Set) storage active = activeSlots[host]; - uint8 idx = activeSlotsIdx[host]; - return active[idx]; + return _toSlotIds(activeSlots[host].values()); } - /// @notice Clears active slots for a host - /// @dev Because there are no efficient ways to clear an EnumerableSet, an index is updated that points to a new instance. + function _equals(RequestId a, RequestId b) internal pure returns (bool) { + return RequestId.unwrap(a) == RequestId.unwrap(b); + } + + /// @notice Clears active slots for a host for a given request + /// @dev WARNING: This could potentially run out of gas if the number of slots is too high. Calling .values() copies to memory. /// @param host address of the host for which to clear the active slots - function _clearActiveSlots(address host) internal { - activeSlotsIdx[host] = activeSlotsIdx[host] + 1; + /// @param requestId identifies the request that the slot must belong to for it to be cleared + function _clearActiveSlotsForRequest(address host, RequestId requestId) internal { + SlotId[] memory slotIds = _activeSlotsForHost(host); + for (uint8 i = 0; i < slotIds.length; i++) { + RequestId slotRequestId = _getRequestIdForSlot(slotIds[i]); + if (_equals(slotRequestId, requestId)) { + activeSlots[host].remove(SlotId.unwrap(slotIds[i])); + } + } } function requestStorage(Request calldata request) @@ -113,7 +123,7 @@ contract Marketplace is Collateral, Proofs { slot.requestId = requestId; RequestContext storage context = _context(requestId); context.slotsFilled += 1; - _activeSlotsForHost(slot.host).add(SlotId.unwrap(slotId)); + activeSlots[slot.host].add(SlotId.unwrap(slotId)); emit SlotFilled(requestId, slotIndex, slotId); if (context.slotsFilled == request.ask.slots) { context.state = RequestState.Started; @@ -141,7 +151,7 @@ contract Marketplace is Collateral, Proofs { _unexpectProofs(_toProofId(slotId)); address slotHost = slot.host; - _activeSlotsForHost(slotHost).remove(SlotId.unwrap(slotId)); + activeSlots[slotHost].remove(SlotId.unwrap(slotId)); slot.host = address(0); slot.requestId = RequestId.wrap(0); context.slotsFilled -= 1; @@ -157,7 +167,7 @@ contract Marketplace is Collateral, Proofs { _setProofEnd(_toEndId(requestId), block.timestamp - 1); context.endsAt = block.timestamp - 1; activeRequests[request.client].remove(RequestId.unwrap(requestId)); - _clearActiveSlots(slotHost); + _clearActiveSlotsForRequest(slotHost, requestId); emit RequestFailed(requestId); // TODO: burn all remaining slot collateral (note: slot collateral not @@ -178,7 +188,7 @@ contract Marketplace is Collateral, Proofs { SlotId slotId = _toSlotId(requestId, slotIndex); Slot storage slot = _slot(slotId); require(!slot.hostPaid, "Already paid"); - _activeSlotsForHost(slot.host).remove(SlotId.unwrap(slotId)); + activeSlots[slot.host].remove(SlotId.unwrap(slotId)); uint256 amount = pricePerSlot(requests[requestId]); funds.sent += amount; funds.balance -= amount; diff --git a/test/Marketplace.test.js b/test/Marketplace.test.js index 738eeb8..e86008a 100644 --- a/test/Marketplace.test.js +++ b/test/Marketplace.test.js @@ -760,10 +760,31 @@ describe("Marketplace", function () { expect(await marketplace.mySlots()).to.deep.equal([slotId(slot1)]) }) - it("removes all slots from list when request fails", async function () { + it("removes all request's active slots when request fails", async function () { + // start first request await waitUntilStarted(marketplace, request, proof) + + // start a second request + let request2 = await exampleRequest() + request2.client = client.address + switchAccount(client) + await token.approve(marketplace.address, price(request2)) + await marketplace.requestStorage(request2) + switchAccount(host) + await waitUntilStarted(marketplace, request2, proof) + + // wait until first request fails await waitUntilFailed(marketplace, request, slot) - expect((await marketplace.mySlots())).to.deep.equal([]) + + // check that our active slots only contains slotIds from second request + let expected = [] + let expectedSlot = { ...slot, index: 0, request: requestId(request2) } + for (let i = 0; i < request2.ask.slots; i++) { + expectedSlot.index = i + let id = slotId(expectedSlot) + expected.push(id) + } + expect(await marketplace.mySlots()).to.deep.equal(expected.reverse()) }) it("removes slots from list when request finishes", async function () {