diff --git a/contracts/Vault.sol b/contracts/Vault.sol index a67f630..a96cc54 100644 --- a/contracts/Vault.sol +++ b/contracts/Vault.sol @@ -1,43 +1,18 @@ // SPDX-License-Identifier: MIT pragma solidity 0.8.28; -import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; -import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; -import "./Timestamps.sol"; +import "./VaultBase.sol"; -using SafeERC20 for IERC20; -using Timestamps for Timestamp; - -contract Vault { - IERC20 private immutable _token; - - type Controller is address; - type Context is bytes32; - type Recipient is address; - - struct Lock { - Timestamp expiry; - Timestamp maximum; - } - - mapping(Controller => mapping(Context => Lock)) private _locks; - mapping(Controller => mapping(Context => mapping(Recipient => uint256))) - private _available; - mapping(Controller => mapping(Context => mapping(Recipient => uint256))) - private _designated; - - constructor(IERC20 token) { - _token = token; - } +contract Vault is VaultBase { + // solhint-disable-next-line no-empty-blocks + constructor(IERC20 token) VaultBase(token) {} function balance( Context context, Recipient recipient ) public view returns (uint256) { Controller controller = Controller.wrap(msg.sender); - return - _available[controller][context][recipient] + - _designated[controller][context][recipient]; + return _getBalance(controller, context, recipient); } function designated( @@ -45,40 +20,32 @@ contract Vault { Recipient recipient ) public view returns (uint256) { Controller controller = Controller.wrap(msg.sender); - return _designated[controller][context][recipient]; + return _getDesignated(controller, context, recipient); } function lock(Context context) public view returns (Lock memory) { Controller controller = Controller.wrap(msg.sender); - return _locks[controller][context]; + return _getLock(controller, context); } function deposit(Context context, address from, uint256 amount) public { Controller controller = Controller.wrap(msg.sender); - Recipient recipient = Recipient.wrap(from); - _available[controller][context][recipient] += amount; - _token.safeTransferFrom(from, address(this), amount); - } - - function _delete(Context context, Recipient recipient) private { - Controller controller = Controller.wrap(msg.sender); - delete _available[controller][context][recipient]; - delete _designated[controller][context][recipient]; + _deposit(controller, context, from, amount); } function withdraw(Context context, Recipient recipient) public { Controller controller = Controller.wrap(msg.sender); - require(!lock(context).expiry.isFuture(), Locked()); - delete _locks[controller][context]; - uint256 amount = balance(context, recipient); - _delete(context, recipient); - _token.safeTransfer(Recipient.unwrap(recipient), amount); + _withdraw(controller, context, recipient); + } + + function withdrawByRecipient(Controller controller, Context context) public { + Recipient recipient = Recipient.wrap(msg.sender); + _withdraw(controller, context, recipient); } function burn(Context context, Recipient recipient) public { - uint256 amount = balance(context, recipient); - _delete(context, recipient); - _token.safeTransfer(address(0xdead), amount); + Controller controller = Controller.wrap(msg.sender); + _burn(controller, context, recipient); } function transfer( @@ -88,12 +55,7 @@ contract Vault { uint256 amount ) public { Controller controller = Controller.wrap(msg.sender); - require( - amount <= _available[controller][context][from], - InsufficientBalance() - ); - _available[controller][context][from] -= amount; - _available[controller][context][to] += amount; + _transfer(controller, context, from, to, amount); } function designate( @@ -102,34 +64,16 @@ contract Vault { uint256 amount ) public { Controller controller = Controller.wrap(msg.sender); - require( - amount <= _available[controller][context][recipient], - InsufficientBalance() - ); - _available[controller][context][recipient] -= amount; - _designated[controller][context][recipient] += amount; + _designate(controller, context, recipient, amount); } function lockup(Context context, Timestamp expiry, Timestamp maximum) public { - require(Timestamp.unwrap(lock(context).maximum) == 0, AlreadyLocked()); - require(!expiry.isAfter(maximum), ExpiryPastMaximum()); Controller controller = Controller.wrap(msg.sender); - _locks[controller][context] = Lock({expiry: expiry, maximum: maximum}); + _lockup(controller, context, expiry, maximum); } function extend(Context context, Timestamp expiry) public { - Lock memory previous = lock(context); - require(previous.expiry.isFuture(), LockExpired()); - require(!previous.expiry.isAfter(expiry), InvalidExpiry()); - require(!expiry.isAfter(previous.maximum), ExpiryPastMaximum()); Controller controller = Controller.wrap(msg.sender); - _locks[controller][context].expiry = expiry; + _extendLock(controller, context, expiry); } - - error InsufficientBalance(); - error Locked(); - error AlreadyLocked(); - error ExpiryPastMaximum(); - error InvalidExpiry(); - error LockExpired(); } diff --git a/contracts/VaultBase.sol b/contracts/VaultBase.sol new file mode 100644 index 0000000..1a76f1a --- /dev/null +++ b/contracts/VaultBase.sol @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.28; + +import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; +import "./Timestamps.sol"; + +using SafeERC20 for IERC20; +using Timestamps for Timestamp; + +abstract contract VaultBase { + IERC20 internal immutable _token; + + type Controller is address; + type Context is bytes32; + type Recipient is address; + + struct Lock { + Timestamp expiry; + Timestamp maximum; + } + + mapping(Controller => mapping(Context => Lock)) private _locks; + mapping(Controller => mapping(Context => mapping(Recipient => uint256))) + private _available; + mapping(Controller => mapping(Context => mapping(Recipient => uint256))) + private _designated; + + constructor(IERC20 token) { + _token = token; + } + + function _getBalance( + Controller controller, + Context context, + Recipient recipient + ) internal view returns (uint256) { + return + _available[controller][context][recipient] + + _designated[controller][context][recipient]; + } + + function _getDesignated( + Controller controller, + Context context, + Recipient recipient + ) internal view returns (uint256) { + return _designated[controller][context][recipient]; + } + + function _getLock( + Controller controller, + Context context + ) internal view returns (Lock memory) { + return _locks[controller][context]; + } + + function _deposit( + Controller controller, + Context context, + address from, + uint256 amount + ) internal { + Recipient recipient = Recipient.wrap(from); + _available[controller][context][recipient] += amount; + _token.safeTransferFrom(from, address(this), amount); + } + + function _delete( + Controller controller, + Context context, + Recipient recipient + ) private { + delete _available[controller][context][recipient]; + delete _designated[controller][context][recipient]; + } + + function _withdraw( + Controller controller, + Context context, + Recipient recipient + ) internal { + require(!_getLock(controller, context).expiry.isFuture(), Locked()); + delete _locks[controller][context]; + uint256 amount = _getBalance(controller, context, recipient); + _delete(controller, context, recipient); + _token.safeTransfer(Recipient.unwrap(recipient), amount); + } + + function _burn( + Controller controller, + Context context, + Recipient recipient + ) internal { + uint256 amount = _getBalance(controller, context, recipient); + _delete(controller, context, recipient); + _token.safeTransfer(address(0xdead), amount); + } + + function _transfer( + Controller controller, + Context context, + Recipient from, + Recipient to, + uint256 amount + ) internal { + require( + amount <= _available[controller][context][from], + InsufficientBalance() + ); + _available[controller][context][from] -= amount; + _available[controller][context][to] += amount; + } + + function _designate( + Controller controller, + Context context, + Recipient recipient, + uint256 amount + ) internal { + require( + amount <= _available[controller][context][recipient], + InsufficientBalance() + ); + _available[controller][context][recipient] -= amount; + _designated[controller][context][recipient] += amount; + } + + function _lockup( + Controller controller, + Context context, + Timestamp expiry, + Timestamp maximum + ) internal { + require( + Timestamp.unwrap(_getLock(controller, context).maximum) == 0, + AlreadyLocked() + ); + require(!expiry.isAfter(maximum), ExpiryPastMaximum()); + _locks[controller][context] = Lock({expiry: expiry, maximum: maximum}); + } + + function _extendLock( + Controller controller, + Context context, + Timestamp expiry + ) internal { + Lock memory previous = _getLock(controller, context); + require(previous.expiry.isFuture(), LockExpired()); + require(!previous.expiry.isAfter(expiry), InvalidExpiry()); + require(!expiry.isAfter(previous.maximum), ExpiryPastMaximum()); + _locks[controller][context].expiry = expiry; + } + + error InsufficientBalance(); + error Locked(); + error AlreadyLocked(); + error ExpiryPastMaximum(); + error InvalidExpiry(); + error LockExpired(); +} diff --git a/test/Vault.tests.js b/test/Vault.tests.js index c73ae0b..2ee8cc7 100644 --- a/test/Vault.tests.js +++ b/test/Vault.tests.js @@ -6,6 +6,7 @@ const { currentTime, advanceTimeToForNextBlock } = require("./evm") describe("Vault", function () { let token let vault + let controller let account, account2, account3 beforeEach(async function () { @@ -13,7 +14,7 @@ describe("Vault", function () { token = await TestToken.deploy() const Vault = await ethers.getContractFactory("Vault") vault = await Vault.deploy(token.address) - ;[, account, account2, account3] = await ethers.getSigners() + ;[controller, account, account2, account3] = await ethers.getSigners() await token.mint(account.address, 1_000_000) await token.mint(account2.address, 1_000_000) await token.mint(account3.address, 1_000_000) @@ -79,13 +80,22 @@ describe("Vault", function () { await vault.deposit(context, account.address, amount) }) - it("can withdraw a deposit", async function () { + it("allows controller to withdraw for a recipient", async function () { const before = await token.balanceOf(account.address) await vault.withdraw(context, account.address) const after = await token.balanceOf(account.address) expect(after - before).to.equal(amount) }) + it("allows recipient to withdraw for itself", async function () { + const before = await token.balanceOf(account.address) + await vault + .connect(account) + .withdrawByRecipient(controller.address, context) + const after = await token.balanceOf(account.address) + expect(after - before).to.equal(amount) + }) + it("empties the balance when withdrawing", async function () { await vault.withdraw(context, account.address) expect(await vault.balance(context, account.address)).to.equal(0)