diff --git a/.gas-snapshot b/.gas-snapshot index 383ec55..763d688 100644 --- a/.gas-snapshot +++ b/.gas-snapshot @@ -7,5 +7,6 @@ MigrateTest:test_RevertWhen_SenderIsNotVault() (gas: 10629) StakeManagerTest:testDeployment() (gas: 26172) StakeTest:testDeployment() (gas: 26172) StakeTest:test_RevertWhen_SenderIsNotVault() (gas: 10638) -UnstakeTest:testDeployment() (gas: 26172) -UnstakeTest:test_RevertWhen_SenderIsNotVault() (gas: 10575) \ No newline at end of file +UnstakeTest:testDeployment() (gas: 26355) +UnstakeTest:test_RevertWhen_FundsLocked() (gas: 973435) +UnstakeTest:test_RevertWhen_SenderIsNotVault() (gas: 10609) \ No newline at end of file diff --git a/contracts/StakeManager.sol b/contracts/StakeManager.sol index 1609857..22e88ce 100644 --- a/contracts/StakeManager.sol +++ b/contracts/StakeManager.sol @@ -8,6 +8,7 @@ import { StakeVault } from "./StakeVault.sol"; contract StakeManager is Ownable { error StakeManager__SenderIsNotVault(); + error StakeManager__FundsLocked(); struct Account { uint256 lockUntil; @@ -74,7 +75,9 @@ contract StakeManager is Ownable { */ function unstake(uint256 _amount) external onlyVault { Account storage account = accounts[msg.sender]; - require(account.lockUntil <= block.timestamp, "Funds are locked"); + if (account.lockUntil > block.timestamp) { + revert StakeManager__FundsLocked(); + } processAccount(account, currentEpoch); uint256 reducedMultiplier = (_amount * account.multiplier) / account.balance; account.multiplier -= reducedMultiplier; diff --git a/test/StakeManager.t.sol b/test/StakeManager.t.sol index 14e5ca3..5280872 100644 --- a/test/StakeManager.t.sol +++ b/test/StakeManager.t.sol @@ -1,10 +1,13 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.19; +import { ERC20 } from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + import { Test } from "forge-std/Test.sol"; import { Deploy } from "../script/Deploy.s.sol"; import { DeploymentConfig } from "../script/DeploymentConfig.s.sol"; import { StakeManager } from "../contracts/StakeManager.sol"; +import { StakeVault } from "../contracts/StakeVault.sol"; contract StakeManagerTest is Test { DeploymentConfig internal deploymentConfig; @@ -12,6 +15,7 @@ contract StakeManagerTest is Test { address internal stakeToken; address internal deployer; + address internal testUser = makeAddr("testUser"); function setUp() public virtual { Deploy deployment = new Deploy(); @@ -29,6 +33,14 @@ contract StakeManagerTest is Test { assertEq(address(stakeManager.oldManager()), address(0)); assertEq(stakeManager.totalSupply(), 0); } + + function _createTestVault(address owner) internal returns (StakeVault vault) { + vm.prank(owner); + vault = new StakeVault(owner, ERC20(stakeToken), stakeManager); + + vm.prank(deployer); + stakeManager.setVault(address(vault).codehash); + } } contract StakeTest is StakeManagerTest { @@ -51,6 +63,21 @@ contract UnstakeTest is StakeManagerTest { vm.expectRevert(StakeManager.StakeManager__SenderIsNotVault.selector); stakeManager.unstake(100); } + + function test_RevertWhen_FundsLocked() public { + // ensure user has funds + deal(stakeToken, testUser, 1000); + StakeVault userVault = _createTestVault(testUser); + + vm.startPrank(testUser); + ERC20(stakeToken).approve(address(userVault), 100); + + uint256 lockTime = 1 days; + userVault.stake(100, lockTime); + + vm.expectRevert(StakeManager.StakeManager__FundsLocked.selector); + userVault.unstake(100); + } } contract LockTest is StakeManagerTest {