diff --git a/contracts/Marketplace.sol b/contracts/Marketplace.sol index b706ce0..4c7e7f4 100644 --- a/contracts/Marketplace.sol +++ b/contracts/Marketplace.sol @@ -7,9 +7,11 @@ import "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import "./Requests.sol"; import "./Collateral.sol"; import "./Proofs.sol"; +import "./StateRetrieval.sol"; -contract Marketplace is Collateral, Proofs { +contract Marketplace is Collateral, Proofs, StateRetrieval { using EnumerableSet for EnumerableSet.Bytes32Set; + using Requests for Request; uint256 public immutable collateral; uint256 public immutable minCollateralThreshold; @@ -20,8 +22,6 @@ contract Marketplace is Collateral, Proofs { mapping(RequestId => Request) private requests; mapping(RequestId => RequestContext) private requestContexts; mapping(SlotId => Slot) private slots; - mapping(address => EnumerableSet.Bytes32Set) private requestsPerClient; // purchasing - mapping(address => EnumerableSet.Bytes32Set) private slotsPerHost; // sales constructor( IERC20 _token, @@ -43,16 +43,8 @@ contract Marketplace is Collateral, Proofs { slashPercentage = _slashPercentage; } - function myRequests() public view returns (RequestId[] memory) { - return _toRequestIds(requestsPerClient[msg.sender].values()); - } - - function mySlots() public view returns (SlotId[] memory) { - return _toSlotIds(slotsPerHost[msg.sender].values()); - } - function isWithdrawAllowed() internal view override returns (bool) { - return slotsPerHost[msg.sender].length() == 0; + return !hasSlots(msg.sender); } function _equals(RequestId a, RequestId b) internal pure returns (bool) { @@ -64,7 +56,7 @@ contract Marketplace is Collateral, Proofs { ) public marketplaceInvariant { require(request.client == msg.sender, "Invalid client address"); - RequestId id = _toRequestId(request); + RequestId id = request.id(); require(requests[id].client == address(0), "Request already exists"); requests[id] = request; @@ -72,7 +64,7 @@ contract Marketplace is Collateral, Proofs { // set contract end time to `duration` from now (time request was created) context.endsAt = block.timestamp + request.ask.duration; - requestsPerClient[request.client].add(RequestId.unwrap(id)); + addToMyRequests(request.client, id); uint256 amount = price(request); funds.received += amount; @@ -104,7 +96,7 @@ contract Marketplace is Collateral, Proofs { RequestContext storage context = _context(requestId); context.slotsFilled += 1; - slotsPerHost[slot.host].add(SlotId.unwrap(slotId)); + addToMySlots(slot.host, slotId); emit SlotFilled(requestId, slotIndex, slotId); if (context.slotsFilled == request.ask.slots) { @@ -121,7 +113,7 @@ contract Marketplace is Collateral, Proofs { if (s == RequestState.Finished || s == RequestState.Cancelled) { payoutSlot(slot.requestId, slotId); } else if (s == RequestState.Failed) { - slotsPerHost[msg.sender].remove(SlotId.unwrap(slotId)); + removeFromMySlots(msg.sender, slotId); } else { _forciblyFreeSlot(slotId); } @@ -164,7 +156,7 @@ contract Marketplace is Collateral, Proofs { _stopRequiringProofs(slotId); - slotsPerHost[slot.host].remove(SlotId.unwrap(slotId)); + removeFromMySlots(slot.host, slotId); slot.host = address(0); slot.requestId = RequestId.wrap(0); @@ -198,11 +190,11 @@ contract Marketplace is Collateral, Proofs { RequestContext storage context = _context(requestId); Request storage request = _request(requestId); context.state = RequestState.Finished; - requestsPerClient[request.client].remove(RequestId.unwrap(requestId)); + removeFromMyRequests(request.client, requestId); Slot storage slot = _slot(slotId); require(!slot.hostPaid, "Already paid"); - slotsPerHost[slot.host].remove(SlotId.unwrap(slotId)); + removeFromMySlots(slot.host, slotId); uint256 amount = pricePerSlot(requests[requestId]); funds.sent += amount; @@ -224,7 +216,7 @@ contract Marketplace is Collateral, Proofs { // Update request state to Cancelled. Handle in the withdraw transaction // as there needs to be someone to pay for the gas to update the state context.state = RequestState.Cancelled; - requestsPerClient[request.client].remove(RequestId.unwrap(requestId)); + removeFromMyRequests(request.client, requestId); emit RequestCancelled(requestId); @@ -407,30 +399,6 @@ contract Marketplace is Collateral, Proofs { return s == RequestState.New || s == RequestState.Started; } - function _toRequestId( - Request memory request - ) internal pure returns (RequestId) { - return RequestId.wrap(keccak256(abi.encode(request))); - } - - function _toRequestIds( - bytes32[] memory array - ) private pure returns (RequestId[] memory result) { - // solhint-disable-next-line no-inline-assembly - assembly { - result := array - } - } - - function _toSlotIds( - bytes32[] memory array - ) private pure returns (SlotId[] memory result) { - // solhint-disable-next-line no-inline-assembly - assembly { - result := array - } - } - function _toBytes32s( RequestId[] memory array ) private pure returns (bytes32[] memory result) { diff --git a/contracts/Requests.sol b/contracts/Requests.sol index 1ba6dd1..250ff29 100644 --- a/contracts/Requests.sol +++ b/contracts/Requests.sol @@ -36,3 +36,27 @@ struct PoR { bytes publicKey; // public key bytes name; // random name } + +library Requests { + function id(Request memory request) internal pure returns (RequestId) { + return RequestId.wrap(keccak256(abi.encode(request))); + } + + function toRequestIds( + bytes32[] memory ids + ) internal pure returns (RequestId[] memory result) { + // solhint-disable-next-line no-inline-assembly + assembly { + result := ids + } + } + + function toSlotIds( + bytes32[] memory ids + ) internal pure returns (SlotId[] memory result) { + // solhint-disable-next-line no-inline-assembly + assembly { + result := ids + } + } +} diff --git a/contracts/StateRetrieval.sol b/contracts/StateRetrieval.sol new file mode 100644 index 0000000..8204e2f --- /dev/null +++ b/contracts/StateRetrieval.sol @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.8; + +import "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; +import "./Requests.sol"; + +contract StateRetrieval { + using EnumerableSet for EnumerableSet.Bytes32Set; + using Requests for bytes32[]; + + mapping(address => EnumerableSet.Bytes32Set) private requestsPerClient; + mapping(address => EnumerableSet.Bytes32Set) private slotsPerHost; + + function myRequests() public view returns (RequestId[] memory) { + return requestsPerClient[msg.sender].values().toRequestIds(); + } + + function mySlots() public view returns (SlotId[] memory) { + return slotsPerHost[msg.sender].values().toSlotIds(); + } + + function hasSlots(address host) internal view returns (bool) { + return slotsPerHost[host].length() > 0; + } + + function addToMyRequests(address client, RequestId requestId) internal { + requestsPerClient[client].add(RequestId.unwrap(requestId)); + } + + function addToMySlots(address host, SlotId slotId) internal { + slotsPerHost[host].add(SlotId.unwrap(slotId)); + } + + function removeFromMyRequests(address client, RequestId requestId) internal { + requestsPerClient[client].remove(RequestId.unwrap(requestId)); + } + + function removeFromMySlots(address host, SlotId slotId) internal { + slotsPerHost[host].remove(SlotId.unwrap(slotId)); + } +}