Change visibility of stake functions to 'internal'

This ensures that any contract that inherits from Stakes
doesn't expose its functions by default.
This commit is contained in:
Mark Spanbroek 2021-11-01 15:17:19 +01:00
parent d1f5ce0786
commit d5dede6e6b
3 changed files with 38 additions and 26 deletions

View File

@ -13,25 +13,25 @@ contract Stakes {
token = _token; token = _token;
} }
function stake(address account) public view returns (uint) { function _stake(address account) internal view returns (uint) {
return stakes[account]; return stakes[account];
} }
function increase(uint amount) public { function _increaseStake(uint amount) internal {
token.transferFrom(msg.sender, address(this), amount); token.transferFrom(msg.sender, address(this), amount);
stakes[msg.sender] += amount; stakes[msg.sender] += amount;
} }
function withdraw() public { function _withdrawStake() internal {
require(locks[msg.sender] == 0, "Stake locked"); require(locks[msg.sender] == 0, "Stake locked");
token.transfer(msg.sender, stakes[msg.sender]); token.transfer(msg.sender, stakes[msg.sender]);
} }
function _lock(address account) internal { function _lockStake(address account) internal {
locks[account] += 1; locks[account] += 1;
} }
function _unlock(address account) internal { function _unlockStake(address account) internal {
require(locks[account] > 0, "Stake already unlocked"); require(locks[account] > 0, "Stake already unlocked");
locks[account] -= 1; locks[account] -= 1;
} }

View File

@ -8,11 +8,23 @@ contract TestStakes is Stakes {
constructor(IERC20 token) Stakes(token) {} constructor(IERC20 token) Stakes(token) {}
function lock(address account) public { function stake(address account) public view returns (uint) {
_lock(account); return _stake(account);
} }
function unlock(address account) public { function increaseStake(uint amount) public {
_unlock(account); _increaseStake(amount);
}
function withdrawStake() public {
_withdrawStake();
}
function lockStake(address account) public {
_lockStake(account);
}
function unlockStake(address account) public {
_unlockStake(account);
} }
} }

View File

@ -23,49 +23,49 @@ describe("Stakes", function () {
it("increases stakes by transferring tokens", async function () { it("increases stakes by transferring tokens", async function () {
await token.approve(stakes.address, 20) await token.approve(stakes.address, 20)
await stakes.increase(20) await stakes.increaseStake(20)
let stake = await stakes.stake(host.address) let stake = await stakes.stake(host.address)
expect(stake).to.equal(20) expect(stake).to.equal(20)
}) })
it("does not increase stake when token transfer fails", async function () { it("does not increase stake when token transfer fails", async function () {
await expect( await expect(
stakes.increase(20) stakes.increaseStake(20)
).to.be.revertedWith("ERC20: transfer amount exceeds allowance") ).to.be.revertedWith("ERC20: transfer amount exceeds allowance")
}) })
it("allows withdrawal of stake", async function () { it("allows withdrawal of stake", async function () {
await token.approve(stakes.address, 20) await token.approve(stakes.address, 20)
await stakes.increase(20) await stakes.increaseStake(20)
let balanceBefore = await token.balanceOf(host.address) let balanceBefore = await token.balanceOf(host.address)
await stakes.withdraw() await stakes.withdrawStake()
let balanceAfter = await token.balanceOf(host.address) let balanceAfter = await token.balanceOf(host.address)
expect(balanceAfter - balanceBefore).to.equal(20) expect(balanceAfter - balanceBefore).to.equal(20)
}) })
it("locks stake", async function () { it("locks stake", async function () {
await token.approve(stakes.address, 20) await token.approve(stakes.address, 20)
await stakes.increase(20) await stakes.increaseStake(20)
await stakes.lock(host.address) await stakes.lockStake(host.address)
await expect(stakes.withdraw()).to.be.revertedWith("Stake locked") await expect(stakes.withdrawStake()).to.be.revertedWith("Stake locked")
await stakes.unlock(host.address) await stakes.unlockStake(host.address)
await expect(stakes.withdraw()).not.to.be.reverted await expect(stakes.withdrawStake()).not.to.be.reverted
}) })
it("fails to unlock when already unlocked", async function () { it("fails to unlock when already unlocked", async function () {
await expect( await expect(
stakes.unlock(host.address) stakes.unlockStake(host.address)
).to.be.revertedWith("Stake already unlocked") ).to.be.revertedWith("Stake already unlocked")
}) })
it("requires an equal amount of locks and unlocks", async function () { it("requires an equal amount of locks and unlocks", async function () {
await token.approve(stakes.address, 20) await token.approve(stakes.address, 20)
await stakes.increase(20) await stakes.increaseStake(20)
await stakes.lock(host.address) await stakes.lockStake(host.address)
await stakes.lock(host.address) await stakes.lockStake(host.address)
await stakes.unlock(host.address) await stakes.unlockStake(host.address)
await expect(stakes.withdraw()).to.be.revertedWith("Stake locked") await expect(stakes.withdrawStake()).to.be.revertedWith("Stake locked")
await stakes.unlock(host.address) await stakes.unlockStake(host.address)
await expect(stakes.withdraw()).not.to.be.reverted await expect(stakes.withdrawStake()).not.to.be.reverted
}) })
}) })