diff --git a/contracts/Stakes.sol b/contracts/Stakes.sol index d6b2533..00352a3 100644 --- a/contracts/Stakes.sol +++ b/contracts/Stakes.sol @@ -13,25 +13,25 @@ contract Stakes { token = _token; } - function stake(address account) public view returns (uint) { + function _stake(address account) internal view returns (uint) { return stakes[account]; } - function increase(uint amount) public { + function _increaseStake(uint amount) internal { token.transferFrom(msg.sender, address(this), amount); stakes[msg.sender] += amount; } - function withdraw() public { + function _withdrawStake() internal { require(locks[msg.sender] == 0, "Stake locked"); token.transfer(msg.sender, stakes[msg.sender]); } - function _lock(address account) internal { + function _lockStake(address account) internal { locks[account] += 1; } - function _unlock(address account) internal { + function _unlockStake(address account) internal { require(locks[account] > 0, "Stake already unlocked"); locks[account] -= 1; } diff --git a/contracts/TestStakes.sol b/contracts/TestStakes.sol index 61d5c13..6041162 100644 --- a/contracts/TestStakes.sol +++ b/contracts/TestStakes.sol @@ -8,11 +8,23 @@ contract TestStakes is Stakes { constructor(IERC20 token) Stakes(token) {} - function lock(address account) public { - _lock(account); + function stake(address account) public view returns (uint) { + return _stake(account); } - function unlock(address account) public { - _unlock(account); + function increaseStake(uint amount) public { + _increaseStake(amount); + } + + function withdrawStake() public { + _withdrawStake(); + } + + function lockStake(address account) public { + _lockStake(account); + } + + function unlockStake(address account) public { + _unlockStake(account); } } diff --git a/test/Stakes.test.js b/test/Stakes.test.js index a5b0619..6458132 100644 --- a/test/Stakes.test.js +++ b/test/Stakes.test.js @@ -23,49 +23,49 @@ describe("Stakes", function () { it("increases stakes by transferring tokens", async function () { await token.approve(stakes.address, 20) - await stakes.increase(20) + await stakes.increaseStake(20) let stake = await stakes.stake(host.address) expect(stake).to.equal(20) }) it("does not increase stake when token transfer fails", async function () { await expect( - stakes.increase(20) + stakes.increaseStake(20) ).to.be.revertedWith("ERC20: transfer amount exceeds allowance") }) it("allows withdrawal of stake", async function () { await token.approve(stakes.address, 20) - await stakes.increase(20) + await stakes.increaseStake(20) let balanceBefore = await token.balanceOf(host.address) - await stakes.withdraw() + await stakes.withdrawStake() let balanceAfter = await token.balanceOf(host.address) expect(balanceAfter - balanceBefore).to.equal(20) }) it("locks stake", async function () { await token.approve(stakes.address, 20) - await stakes.increase(20) - await stakes.lock(host.address) - await expect(stakes.withdraw()).to.be.revertedWith("Stake locked") - await stakes.unlock(host.address) - await expect(stakes.withdraw()).not.to.be.reverted + await stakes.increaseStake(20) + await stakes.lockStake(host.address) + await expect(stakes.withdrawStake()).to.be.revertedWith("Stake locked") + await stakes.unlockStake(host.address) + await expect(stakes.withdrawStake()).not.to.be.reverted }) it("fails to unlock when already unlocked", async function () { await expect( - stakes.unlock(host.address) + stakes.unlockStake(host.address) ).to.be.revertedWith("Stake already unlocked") }) it("requires an equal amount of locks and unlocks", async function () { await token.approve(stakes.address, 20) - await stakes.increase(20) - await stakes.lock(host.address) - await stakes.lock(host.address) - await stakes.unlock(host.address) - await expect(stakes.withdraw()).to.be.revertedWith("Stake locked") - await stakes.unlock(host.address) - await expect(stakes.withdraw()).not.to.be.reverted + await stakes.increaseStake(20) + await stakes.lockStake(host.address) + await stakes.lockStake(host.address) + await stakes.unlockStake(host.address) + await expect(stakes.withdrawStake()).to.be.revertedWith("Stake locked") + await stakes.unlockStake(host.address) + await expect(stakes.withdrawStake()).not.to.be.reverted }) })