diff --git a/contracts/Collateral.sol b/contracts/Collateral.sol index 7ad8a3a..efca0da 100644 --- a/contracts/Collateral.sol +++ b/contracts/Collateral.sol @@ -6,10 +6,11 @@ import "./AccountLocks.sol"; contract Collateral is AccountLocks { IERC20 public immutable token; - Totals private totals; + CollateralFunds private funds; + mapping(address => uint256) private balances; - constructor(IERC20 _token) invariant { + constructor(IERC20 _token) collateralInvariant { token = _token; } @@ -19,51 +20,52 @@ contract Collateral is AccountLocks { function add(address account, uint256 amount) private { balances[account] += amount; - totals.balance += amount; + funds.balance += amount; } function subtract(address account, uint256 amount) private { balances[account] -= amount; - totals.balance -= amount; + funds.balance -= amount; } - function transferFrom(address sender, uint256 amount) private { + function transferFrom(address sender, uint256 amount) internal { address receiver = address(this); require(token.transferFrom(sender, receiver, amount), "Transfer failed"); } - function deposit(uint256 amount) public invariant { + function deposit(uint256 amount) public collateralInvariant { transferFrom(msg.sender, amount); - totals.deposited += amount; + funds.deposited += amount; add(msg.sender, amount); } - function withdraw() public invariant { + function withdraw() public collateralInvariant { _unlockAccount(); uint256 amount = balanceOf(msg.sender); - totals.withdrawn += amount; + funds.withdrawn += amount; subtract(msg.sender, amount); assert(token.transfer(msg.sender, amount)); } - function _slash(address account, uint256 percentage) internal invariant { + function _slash(address account, uint256 percentage) + internal + collateralInvariant + { uint256 amount = (balanceOf(account) * percentage) / 100; - totals.slashed += amount; + funds.slashed += amount; subtract(account, amount); } - modifier invariant() { - Totals memory oldTotals = totals; + modifier collateralInvariant() { + CollateralFunds memory oldFunds = funds; _; - assert(totals.deposited >= oldTotals.deposited); - assert(totals.withdrawn >= oldTotals.withdrawn); - assert(totals.slashed >= oldTotals.slashed); - assert( - totals.deposited == totals.balance + totals.withdrawn + totals.slashed - ); + assert(funds.deposited >= oldFunds.deposited); + assert(funds.withdrawn >= oldFunds.withdrawn); + assert(funds.slashed >= oldFunds.slashed); + assert(funds.deposited == funds.balance + funds.withdrawn + funds.slashed); } - struct Totals { + struct CollateralFunds { uint256 balance; uint256 deposited; uint256 withdrawn; diff --git a/contracts/Marketplace.sol b/contracts/Marketplace.sol index 3629151..cb27b1c 100644 --- a/contracts/Marketplace.sol +++ b/contracts/Marketplace.sol @@ -5,10 +5,10 @@ import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; contract Marketplace { IERC20 public immutable token; - Totals private totals; + MarketplaceFunds private funds; mapping(bytes32 => Request) private requests; - constructor(IERC20 _token) invariant { + constructor(IERC20 _token) marketplaceInvariant { token = _token; } @@ -17,14 +17,17 @@ contract Marketplace { require(token.transferFrom(sender, receiver, amount), "Transfer failed"); } - function requestStorage(Request calldata request) public invariant { + function requestStorage(Request calldata request) + public + marketplaceInvariant + { bytes32 id = keccak256(abi.encode(request)); require(request.size > 0, "Invalid size"); require(requests[id].size == 0, "Request already exists"); requests[id] = request; transferFrom(msg.sender, request.maxPrice); - totals.received += request.maxPrice; - totals.balance += request.maxPrice; + funds.received += request.maxPrice; + funds.balance += request.maxPrice; emit StorageRequested(id, request); } @@ -40,15 +43,15 @@ contract Marketplace { event StorageRequested(bytes32 id, Request request); - modifier invariant() { - Totals memory oldTotals = totals; + modifier marketplaceInvariant() { + MarketplaceFunds memory oldFunds = funds; _; - assert(totals.received >= oldTotals.received); - assert(totals.sent >= oldTotals.sent); - assert(totals.received == totals.balance + totals.sent); + assert(funds.received >= oldFunds.received); + assert(funds.sent >= oldFunds.sent); + assert(funds.received == funds.balance + funds.sent); } - struct Totals { + struct MarketplaceFunds { uint256 balance; uint256 received; uint256 sent;