diff --git a/contracts/Collateral.sol b/contracts/Collateral.sol index cd29b65..4036026 100644 --- a/contracts/Collateral.sol +++ b/contracts/Collateral.sol @@ -5,11 +5,9 @@ import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; contract Collateral { IERC20 private immutable token; + Totals private totals; mapping(address => uint256) private balances; - uint256 private totalDeposited; - uint256 private totalBalance; - constructor(IERC20 _token) invariant { token = _token; } @@ -20,13 +18,30 @@ contract Collateral { function deposit(uint256 amount) public invariant { token.transferFrom(msg.sender, address(this), amount); - totalDeposited += amount; + totals.deposited += amount; balances[msg.sender] += amount; - totalBalance += amount; + totals.balance += amount; + } + + function withdraw() public invariant { + uint256 amount = balances[msg.sender]; + balances[msg.sender] = 0; + totals.balance -= amount; + totals.withdrawn += amount; + assert(token.transfer(msg.sender, amount)); } modifier invariant() { + Totals memory oldTotals = totals; _; - assert(totalDeposited == totalBalance); + assert(totals.deposited >= oldTotals.deposited); + assert(totals.withdrawn >= oldTotals.withdrawn); + assert(totals.deposited == totals.balance + totals.withdrawn); + } + + struct Totals { + uint256 balance; + uint256 deposited; + uint256 withdrawn; } } diff --git a/test/Collateral.js b/test/Collateral.js index bec3885..1bdaed8 100644 --- a/test/Collateral.js +++ b/test/Collateral.js @@ -47,4 +47,30 @@ describe("Collateral", function () { ) }) }) + + describe("withdrawing", function () { + beforeEach(async function () { + await token.connect(account0).approve(collateral.address, 100) + await token.connect(account1).approve(collateral.address, 100) + await collateral.connect(account0).deposit(40) + await collateral.connect(account1).deposit(2) + }) + + it("updates the amount of collateral", async function () { + await collateral.connect(account0).withdraw() + expect(await collateral.balanceOf(account0.address)).to.equal(0) + expect(await collateral.balanceOf(account1.address)).to.equal(2) + await collateral.connect(account1).withdraw() + expect(await collateral.balanceOf(account0.address)).to.equal(0) + expect(await collateral.balanceOf(account1.address)).to.equal(0) + }) + + it("transfers balance to owner", async function () { + let balance = await collateral.balanceOf(account0.address) + let before = await token.balanceOf(account0.address) + await collateral.withdraw() + let after = await token.balanceOf(account0.address) + expect(after - before).to.equal(balance) + }) + }) })