diff --git a/contracts/Stakes.sol b/contracts/Stakes.sol new file mode 100644 index 0000000..d6b2533 --- /dev/null +++ b/contracts/Stakes.sol @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; + +contract Stakes { + + IERC20 private token; + mapping(address=>uint) private stakes; + mapping(address=>uint) private locks; + + constructor(IERC20 _token) { + token = _token; + } + + function stake(address account) public view returns (uint) { + return stakes[account]; + } + + function increase(uint amount) public { + token.transferFrom(msg.sender, address(this), amount); + stakes[msg.sender] += amount; + } + + function withdraw() public { + require(locks[msg.sender] == 0, "Stake locked"); + token.transfer(msg.sender, stakes[msg.sender]); + } + + function _lock(address account) internal { + locks[account] += 1; + } + + function _unlock(address account) internal { + require(locks[account] > 0, "Stake already unlocked"); + locks[account] -= 1; + } +} diff --git a/contracts/TestStakes.sol b/contracts/TestStakes.sol new file mode 100644 index 0000000..61d5c13 --- /dev/null +++ b/contracts/TestStakes.sol @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import "./Stakes.sol"; + +// exposes internal functions of Stakes for testing +contract TestStakes is Stakes { + + constructor(IERC20 token) Stakes(token) {} + + function lock(address account) public { + _lock(account); + } + + function unlock(address account) public { + _unlock(account); + } +} diff --git a/contracts/TestToken.sol b/contracts/TestToken.sol new file mode 100644 index 0000000..9d59ecc --- /dev/null +++ b/contracts/TestToken.sol @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +contract TestToken is ERC20 { + constructor() ERC20("TestToken", "TST") { + _mint(msg.sender, 1000); + } +} diff --git a/test/Stakes.test.js b/test/Stakes.test.js new file mode 100644 index 0000000..a5b0619 --- /dev/null +++ b/test/Stakes.test.js @@ -0,0 +1,71 @@ +const { expect } = require("chai") +const { ethers } = require("hardhat") + +describe("Stakes", function () { + + var stakes + var token + var host + + beforeEach(async function() { + [host] = await ethers.getSigners() + const Stakes = await ethers.getContractFactory("TestStakes") + const TestToken = await ethers.getContractFactory("TestToken") + token = await TestToken.deploy() + stakes = await Stakes.deploy(token.address) + }) + + it("has zero stakes initially", async function () { + const address = await host.getAddress() + const stake = await stakes.stake(address) + expect(stake).to.equal(0) + }) + + it("increases stakes by transferring tokens", async function () { + await token.approve(stakes.address, 20) + await stakes.increase(20) + let stake = await stakes.stake(host.address) + expect(stake).to.equal(20) + }) + + it("does not increase stake when token transfer fails", async function () { + await expect( + stakes.increase(20) + ).to.be.revertedWith("ERC20: transfer amount exceeds allowance") + }) + + it("allows withdrawal of stake", async function () { + await token.approve(stakes.address, 20) + await stakes.increase(20) + let balanceBefore = await token.balanceOf(host.address) + await stakes.withdraw() + let balanceAfter = await token.balanceOf(host.address) + expect(balanceAfter - balanceBefore).to.equal(20) + }) + + it("locks stake", async function () { + await token.approve(stakes.address, 20) + await stakes.increase(20) + await stakes.lock(host.address) + await expect(stakes.withdraw()).to.be.revertedWith("Stake locked") + await stakes.unlock(host.address) + await expect(stakes.withdraw()).not.to.be.reverted + }) + + it("fails to unlock when already unlocked", async function () { + await expect( + stakes.unlock(host.address) + ).to.be.revertedWith("Stake already unlocked") + }) + + it("requires an equal amount of locks and unlocks", async function () { + await token.approve(stakes.address, 20) + await stakes.increase(20) + await stakes.lock(host.address) + await stakes.lock(host.address) + await stakes.unlock(host.address) + await expect(stakes.withdraw()).to.be.revertedWith("Stake locked") + await stakes.unlock(host.address) + await expect(stakes.withdraw()).not.to.be.reverted + }) +})