vault: allow recipient to withdraw

This commit is contained in:
Mark Spanbroek 2025-01-22 11:32:22 +01:00
parent 834255c871
commit 922121e659
3 changed files with 193 additions and 78 deletions

View File

@ -1,43 +1,18 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
pragma solidity 0.8.28; pragma solidity 0.8.28;
import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import "./VaultBase.sol";
import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol";
import "./Timestamps.sol";
using SafeERC20 for IERC20; contract Vault is VaultBase {
using Timestamps for Timestamp; // solhint-disable-next-line no-empty-blocks
constructor(IERC20 token) VaultBase(token) {}
contract Vault {
IERC20 private immutable _token;
type Controller is address;
type Context is bytes32;
type Recipient is address;
struct Lock {
Timestamp expiry;
Timestamp maximum;
}
mapping(Controller => mapping(Context => Lock)) private _locks;
mapping(Controller => mapping(Context => mapping(Recipient => uint256)))
private _available;
mapping(Controller => mapping(Context => mapping(Recipient => uint256)))
private _designated;
constructor(IERC20 token) {
_token = token;
}
function balance( function balance(
Context context, Context context,
Recipient recipient Recipient recipient
) public view returns (uint256) { ) public view returns (uint256) {
Controller controller = Controller.wrap(msg.sender); Controller controller = Controller.wrap(msg.sender);
return return _getBalance(controller, context, recipient);
_available[controller][context][recipient] +
_designated[controller][context][recipient];
} }
function designated( function designated(
@ -45,40 +20,32 @@ contract Vault {
Recipient recipient Recipient recipient
) public view returns (uint256) { ) public view returns (uint256) {
Controller controller = Controller.wrap(msg.sender); Controller controller = Controller.wrap(msg.sender);
return _designated[controller][context][recipient]; return _getDesignated(controller, context, recipient);
} }
function lock(Context context) public view returns (Lock memory) { function lock(Context context) public view returns (Lock memory) {
Controller controller = Controller.wrap(msg.sender); Controller controller = Controller.wrap(msg.sender);
return _locks[controller][context]; return _getLock(controller, context);
} }
function deposit(Context context, address from, uint256 amount) public { function deposit(Context context, address from, uint256 amount) public {
Controller controller = Controller.wrap(msg.sender); Controller controller = Controller.wrap(msg.sender);
Recipient recipient = Recipient.wrap(from); _deposit(controller, context, from, amount);
_available[controller][context][recipient] += amount;
_token.safeTransferFrom(from, address(this), amount);
}
function _delete(Context context, Recipient recipient) private {
Controller controller = Controller.wrap(msg.sender);
delete _available[controller][context][recipient];
delete _designated[controller][context][recipient];
} }
function withdraw(Context context, Recipient recipient) public { function withdraw(Context context, Recipient recipient) public {
Controller controller = Controller.wrap(msg.sender); Controller controller = Controller.wrap(msg.sender);
require(!lock(context).expiry.isFuture(), Locked()); _withdraw(controller, context, recipient);
delete _locks[controller][context]; }
uint256 amount = balance(context, recipient);
_delete(context, recipient); function withdrawByRecipient(Controller controller, Context context) public {
_token.safeTransfer(Recipient.unwrap(recipient), amount); Recipient recipient = Recipient.wrap(msg.sender);
_withdraw(controller, context, recipient);
} }
function burn(Context context, Recipient recipient) public { function burn(Context context, Recipient recipient) public {
uint256 amount = balance(context, recipient); Controller controller = Controller.wrap(msg.sender);
_delete(context, recipient); _burn(controller, context, recipient);
_token.safeTransfer(address(0xdead), amount);
} }
function transfer( function transfer(
@ -88,12 +55,7 @@ contract Vault {
uint256 amount uint256 amount
) public { ) public {
Controller controller = Controller.wrap(msg.sender); Controller controller = Controller.wrap(msg.sender);
require( _transfer(controller, context, from, to, amount);
amount <= _available[controller][context][from],
InsufficientBalance()
);
_available[controller][context][from] -= amount;
_available[controller][context][to] += amount;
} }
function designate( function designate(
@ -102,34 +64,16 @@ contract Vault {
uint256 amount uint256 amount
) public { ) public {
Controller controller = Controller.wrap(msg.sender); Controller controller = Controller.wrap(msg.sender);
require( _designate(controller, context, recipient, amount);
amount <= _available[controller][context][recipient],
InsufficientBalance()
);
_available[controller][context][recipient] -= amount;
_designated[controller][context][recipient] += amount;
} }
function lockup(Context context, Timestamp expiry, Timestamp maximum) public { function lockup(Context context, Timestamp expiry, Timestamp maximum) public {
require(Timestamp.unwrap(lock(context).maximum) == 0, AlreadyLocked());
require(!expiry.isAfter(maximum), ExpiryPastMaximum());
Controller controller = Controller.wrap(msg.sender); Controller controller = Controller.wrap(msg.sender);
_locks[controller][context] = Lock({expiry: expiry, maximum: maximum}); _lockup(controller, context, expiry, maximum);
} }
function extend(Context context, Timestamp expiry) public { function extend(Context context, Timestamp expiry) public {
Lock memory previous = lock(context);
require(previous.expiry.isFuture(), LockExpired());
require(!previous.expiry.isAfter(expiry), InvalidExpiry());
require(!expiry.isAfter(previous.maximum), ExpiryPastMaximum());
Controller controller = Controller.wrap(msg.sender); Controller controller = Controller.wrap(msg.sender);
_locks[controller][context].expiry = expiry; _extendLock(controller, context, expiry);
} }
error InsufficientBalance();
error Locked();
error AlreadyLocked();
error ExpiryPastMaximum();
error InvalidExpiry();
error LockExpired();
} }

161
contracts/VaultBase.sol Normal file
View File

@ -0,0 +1,161 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.28;
import "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol";
import "./Timestamps.sol";
using SafeERC20 for IERC20;
using Timestamps for Timestamp;
abstract contract VaultBase {
IERC20 internal immutable _token;
type Controller is address;
type Context is bytes32;
type Recipient is address;
struct Lock {
Timestamp expiry;
Timestamp maximum;
}
mapping(Controller => mapping(Context => Lock)) private _locks;
mapping(Controller => mapping(Context => mapping(Recipient => uint256)))
private _available;
mapping(Controller => mapping(Context => mapping(Recipient => uint256)))
private _designated;
constructor(IERC20 token) {
_token = token;
}
function _getBalance(
Controller controller,
Context context,
Recipient recipient
) internal view returns (uint256) {
return
_available[controller][context][recipient] +
_designated[controller][context][recipient];
}
function _getDesignated(
Controller controller,
Context context,
Recipient recipient
) internal view returns (uint256) {
return _designated[controller][context][recipient];
}
function _getLock(
Controller controller,
Context context
) internal view returns (Lock memory) {
return _locks[controller][context];
}
function _deposit(
Controller controller,
Context context,
address from,
uint256 amount
) internal {
Recipient recipient = Recipient.wrap(from);
_available[controller][context][recipient] += amount;
_token.safeTransferFrom(from, address(this), amount);
}
function _delete(
Controller controller,
Context context,
Recipient recipient
) private {
delete _available[controller][context][recipient];
delete _designated[controller][context][recipient];
}
function _withdraw(
Controller controller,
Context context,
Recipient recipient
) internal {
require(!_getLock(controller, context).expiry.isFuture(), Locked());
delete _locks[controller][context];
uint256 amount = _getBalance(controller, context, recipient);
_delete(controller, context, recipient);
_token.safeTransfer(Recipient.unwrap(recipient), amount);
}
function _burn(
Controller controller,
Context context,
Recipient recipient
) internal {
uint256 amount = _getBalance(controller, context, recipient);
_delete(controller, context, recipient);
_token.safeTransfer(address(0xdead), amount);
}
function _transfer(
Controller controller,
Context context,
Recipient from,
Recipient to,
uint256 amount
) internal {
require(
amount <= _available[controller][context][from],
InsufficientBalance()
);
_available[controller][context][from] -= amount;
_available[controller][context][to] += amount;
}
function _designate(
Controller controller,
Context context,
Recipient recipient,
uint256 amount
) internal {
require(
amount <= _available[controller][context][recipient],
InsufficientBalance()
);
_available[controller][context][recipient] -= amount;
_designated[controller][context][recipient] += amount;
}
function _lockup(
Controller controller,
Context context,
Timestamp expiry,
Timestamp maximum
) internal {
require(
Timestamp.unwrap(_getLock(controller, context).maximum) == 0,
AlreadyLocked()
);
require(!expiry.isAfter(maximum), ExpiryPastMaximum());
_locks[controller][context] = Lock({expiry: expiry, maximum: maximum});
}
function _extendLock(
Controller controller,
Context context,
Timestamp expiry
) internal {
Lock memory previous = _getLock(controller, context);
require(previous.expiry.isFuture(), LockExpired());
require(!previous.expiry.isAfter(expiry), InvalidExpiry());
require(!expiry.isAfter(previous.maximum), ExpiryPastMaximum());
_locks[controller][context].expiry = expiry;
}
error InsufficientBalance();
error Locked();
error AlreadyLocked();
error ExpiryPastMaximum();
error InvalidExpiry();
error LockExpired();
}

View File

@ -6,6 +6,7 @@ const { currentTime, advanceTimeToForNextBlock } = require("./evm")
describe("Vault", function () { describe("Vault", function () {
let token let token
let vault let vault
let controller
let account, account2, account3 let account, account2, account3
beforeEach(async function () { beforeEach(async function () {
@ -13,7 +14,7 @@ describe("Vault", function () {
token = await TestToken.deploy() token = await TestToken.deploy()
const Vault = await ethers.getContractFactory("Vault") const Vault = await ethers.getContractFactory("Vault")
vault = await Vault.deploy(token.address) vault = await Vault.deploy(token.address)
;[, account, account2, account3] = await ethers.getSigners() ;[controller, account, account2, account3] = await ethers.getSigners()
await token.mint(account.address, 1_000_000) await token.mint(account.address, 1_000_000)
await token.mint(account2.address, 1_000_000) await token.mint(account2.address, 1_000_000)
await token.mint(account3.address, 1_000_000) await token.mint(account3.address, 1_000_000)
@ -79,13 +80,22 @@ describe("Vault", function () {
await vault.deposit(context, account.address, amount) await vault.deposit(context, account.address, amount)
}) })
it("can withdraw a deposit", async function () { it("allows controller to withdraw for a recipient", async function () {
const before = await token.balanceOf(account.address) const before = await token.balanceOf(account.address)
await vault.withdraw(context, account.address) await vault.withdraw(context, account.address)
const after = await token.balanceOf(account.address) const after = await token.balanceOf(account.address)
expect(after - before).to.equal(amount) expect(after - before).to.equal(amount)
}) })
it("allows recipient to withdraw for itself", async function () {
const before = await token.balanceOf(account.address)
await vault
.connect(account)
.withdrawByRecipient(controller.address, context)
const after = await token.balanceOf(account.address)
expect(after - before).to.equal(amount)
})
it("empties the balance when withdrawing", async function () { it("empties the balance when withdrawing", async function () {
await vault.withdraw(context, account.address) await vault.withdraw(context, account.address)
expect(await vault.balance(context, account.address)).to.equal(0) expect(await vault.balance(context, account.address)).to.equal(0)