vault: flow to multiple recipients

- changes balance from uint256 -> uint128
  so that entire Balance can be read or written
  with a single operation
- moves Lock to library
- simplifies lock checks
This commit is contained in:
Mark Spanbroek 2025-01-28 10:31:39 +01:00
parent 4f880bb08a
commit db8b06a51b
8 changed files with 197 additions and 87 deletions

View File

@ -1,22 +0,0 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.28;
import "./Timestamps.sol";
type TokensPerSecond is int256;
using {_tokensPerSecondNegate as -} for TokensPerSecond global;
using {_tokensPerSecondEquals as ==} for TokensPerSecond global;
function _tokensPerSecondNegate(
TokensPerSecond rate
) pure returns (TokensPerSecond) {
return TokensPerSecond.wrap(-TokensPerSecond.unwrap(rate));
}
function _tokensPerSecondEquals(
TokensPerSecond a,
TokensPerSecond b
) pure returns (bool) {
return TokensPerSecond.unwrap(a) == TokensPerSecond.unwrap(b);
}

View File

@ -1,7 +1,7 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.28;
import "./VaultBase.sol";
import "./vault/VaultBase.sol";
contract Vault is VaultBase {
// solhint-disable-next-line no-empty-blocks
@ -10,7 +10,7 @@ contract Vault is VaultBase {
function getBalance(
Context context,
Recipient recipient
) public view returns (uint256) {
) public view returns (uint128) {
Controller controller = Controller.wrap(msg.sender);
Balance memory b = _getBalance(controller, context, recipient);
return b.available + b.designated;
@ -19,7 +19,7 @@ contract Vault is VaultBase {
function getDesignatedBalance(
Context context,
Recipient recipient
) public view returns (uint256) {
) public view returns (uint128) {
Controller controller = Controller.wrap(msg.sender);
Balance memory b = _getBalance(controller, context, recipient);
return b.designated;
@ -30,7 +30,7 @@ contract Vault is VaultBase {
return _getLock(controller, context);
}
function deposit(Context context, address from, uint256 amount) public {
function deposit(Context context, address from, uint128 amount) public {
Controller controller = Controller.wrap(msg.sender);
_deposit(controller, context, from, amount);
}
@ -54,7 +54,7 @@ contract Vault is VaultBase {
Context context,
Recipient from,
Recipient to,
uint256 amount
uint128 amount
) public {
Controller controller = Controller.wrap(msg.sender);
_transfer(controller, context, from, to, amount);
@ -63,7 +63,7 @@ contract Vault is VaultBase {
function designate(
Context context,
Recipient recipient,
uint256 amount
uint128 amount
) public {
Controller controller = Controller.wrap(msg.sender);
_designate(controller, context, recipient, amount);

23
contracts/vault/Flows.sol Normal file
View File

@ -0,0 +1,23 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.28;
import "./Timestamps.sol";
import "./TokensPerSecond.sol";
struct Flow {
Timestamp start;
TokensPerSecond rate;
}
library Flows {
function _totalAt(
Flow memory flow,
Timestamp timestamp
) internal pure returns (int128) {
int128 rate = TokensPerSecond.unwrap(flow.rate);
Timestamp start = flow.start;
Timestamp end = timestamp;
uint64 duration = Timestamp.unwrap(end) - Timestamp.unwrap(start);
return rate * int128(uint128(duration));
}
}

15
contracts/vault/Locks.sol Normal file
View File

@ -0,0 +1,15 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.28;
import "./Timestamps.sol";
struct Lock {
Timestamp expiry;
Timestamp maximum;
}
library Locks {
function isLocked(Lock memory lock) internal view returns (bool) {
return Timestamps.currentTime() < lock.expiry;
}
}

View File

@ -0,0 +1,56 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.28;
import "./Timestamps.sol";
type TokensPerSecond is int128;
using {_tokensPerSecondNegate as -} for TokensPerSecond global;
using {_tokensPerSecondMinus as -} for TokensPerSecond global;
using {_tokensPerSecondPlus as +} for TokensPerSecond global;
using {_tokensPerSecondEquals as ==} for TokensPerSecond global;
using {_tokensPerSecondNotEqual as !=} for TokensPerSecond global;
using {_tokensPerSecondAtLeast as >=} for TokensPerSecond global;
function _tokensPerSecondNegate(
TokensPerSecond rate
) pure returns (TokensPerSecond) {
return TokensPerSecond.wrap(-TokensPerSecond.unwrap(rate));
}
function _tokensPerSecondMinus(
TokensPerSecond a,
TokensPerSecond b
) pure returns (TokensPerSecond) {
return
TokensPerSecond.wrap(TokensPerSecond.unwrap(a) - TokensPerSecond.unwrap(b));
}
function _tokensPerSecondPlus(
TokensPerSecond a,
TokensPerSecond b
) pure returns (TokensPerSecond) {
return
TokensPerSecond.wrap(TokensPerSecond.unwrap(a) + TokensPerSecond.unwrap(b));
}
function _tokensPerSecondEquals(
TokensPerSecond a,
TokensPerSecond b
) pure returns (bool) {
return TokensPerSecond.unwrap(a) == TokensPerSecond.unwrap(b);
}
function _tokensPerSecondNotEqual(
TokensPerSecond a,
TokensPerSecond b
) pure returns (bool) {
return TokensPerSecond.unwrap(a) != TokensPerSecond.unwrap(b);
}
function _tokensPerSecondAtLeast(
TokensPerSecond a,
TokensPerSecond b
) pure returns (bool) {
return TokensPerSecond.unwrap(a) >= TokensPerSecond.unwrap(b);
}

View File

@ -5,9 +5,13 @@ import "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol";
import "./Timestamps.sol";
import "./TokensPerSecond.sol";
import "./Flows.sol";
import "./Locks.sol";
using SafeERC20 for IERC20;
using Timestamps for Timestamp;
using Flows for Flow;
using Locks for Lock;
abstract contract VaultBase {
IERC20 internal immutable _token;
@ -17,18 +21,8 @@ abstract contract VaultBase {
type Recipient is address;
struct Balance {
uint256 available;
uint256 designated;
}
struct Lock {
Timestamp expiry;
Timestamp maximum;
}
struct Flow {
Timestamp start;
TokensPerSecond rate;
uint128 available;
uint128 designated;
}
mapping(Controller => mapping(Context => Lock)) private _locks;
@ -46,29 +40,30 @@ abstract contract VaultBase {
Context context,
Recipient recipient
) internal view returns (Balance memory) {
Balance memory balance = _balances[controller][context][recipient];
int256 accumulated = _accumulateFlow(controller, context, recipient);
if (accumulated >= 0) {
balance.designated += uint256(accumulated);
} else {
balance.available -= uint256(-accumulated);
}
return balance;
Balance storage balance = _balances[controller][context][recipient];
Flow storage flow = _flows[controller][context][recipient];
Lock storage lock = _locks[controller][context];
Timestamp timestamp = Timestamps.currentTime();
return _getBalanceAt(balance, flow, lock, timestamp);
}
function _accumulateFlow(
Controller controller,
Context context,
Recipient recipient
) private view returns (int256) {
Flow memory flow = _flows[controller][context][recipient];
if (flow.rate == TokensPerSecond.wrap(0)) {
return 0;
function _getBalanceAt(
Balance memory balance,
Flow memory flow,
Lock storage lock,
Timestamp timestamp
) private view returns (Balance memory) {
Balance memory result = balance;
if (flow.rate != TokensPerSecond.wrap(0)) {
Timestamp end = Timestamps.earliest(timestamp, lock.expiry);
int128 accumulated = flow._totalAt(end);
if (accumulated >= 0) {
result.designated += uint128(accumulated);
} else {
result.available -= uint128(-accumulated);
}
}
Timestamp expiry = _getLock(controller, context).expiry;
Timestamp flowEnd = Timestamps.earliest(Timestamps.currentTime(), expiry);
uint64 duration = Timestamp.unwrap(flowEnd) - Timestamp.unwrap(flow.start);
return TokensPerSecond.unwrap(flow.rate) * int256(uint256(duration));
return result;
}
function _getLock(
@ -82,7 +77,7 @@ abstract contract VaultBase {
Controller controller,
Context context,
address from,
uint256 amount
uint128 amount
) internal {
Recipient recipient = Recipient.wrap(from);
_balances[controller][context][recipient].available += amount;
@ -102,13 +97,10 @@ abstract contract VaultBase {
Context context,
Recipient recipient
) internal {
require(
_getLock(controller, context).expiry <= Timestamps.currentTime(),
Locked()
);
require(!_locks[controller][context].isLocked(), Locked());
delete _locks[controller][context];
Balance memory balance = _getBalance(controller, context, recipient);
uint256 amount = balance.available + balance.designated;
uint128 amount = balance.available + balance.designated;
_delete(controller, context, recipient);
_token.safeTransfer(Recipient.unwrap(recipient), amount);
}
@ -119,7 +111,7 @@ abstract contract VaultBase {
Recipient recipient
) internal {
Balance memory balance = _getBalance(controller, context, recipient);
uint256 amount = balance.available + balance.designated;
uint128 amount = balance.available + balance.designated;
_delete(controller, context, recipient);
_token.safeTransfer(address(0xdead), amount);
}
@ -129,7 +121,7 @@ abstract contract VaultBase {
Context context,
Recipient from,
Recipient to,
uint256 amount
uint128 amount
) internal {
require(
amount <= _balances[controller][context][from].available,
@ -143,7 +135,7 @@ abstract contract VaultBase {
Controller controller,
Context context,
Recipient recipient,
uint256 amount
uint128 amount
) internal {
Balance storage balance = _balances[controller][context][recipient];
require(amount <= balance.available, InsufficientBalance());
@ -157,9 +149,9 @@ abstract contract VaultBase {
Timestamp expiry,
Timestamp maximum
) internal {
Lock memory existing = _getLock(controller, context);
require(existing.maximum == Timestamp.wrap(0), AlreadyLocked());
require(expiry <= maximum, ExpiryPastMaximum());
Lock memory existing = _locks[controller][context];
require(existing.maximum == Timestamp.wrap(0), AlreadyLocked());
_locks[controller][context] = Lock({expiry: expiry, maximum: maximum});
}
@ -168,10 +160,10 @@ abstract contract VaultBase {
Context context,
Timestamp expiry
) internal {
Lock memory existing = _getLock(controller, context);
require(Timestamps.currentTime() < existing.expiry, LockExpired());
require(existing.expiry <= expiry, InvalidExpiry());
require(expiry <= existing.maximum, ExpiryPastMaximum());
Lock memory lock = _locks[controller][context];
require(lock.isLocked(), LockRequired());
require(lock.expiry <= expiry, InvalidExpiry());
require(expiry <= lock.maximum, ExpiryPastMaximum());
_locks[controller][context].expiry = expiry;
}
@ -182,16 +174,25 @@ abstract contract VaultBase {
Recipient to,
TokensPerSecond rate
) internal {
Lock memory lock = _getLock(controller, context);
require(rate >= TokensPerSecond.wrap(0), NegativeFlow());
Lock memory lock = _locks[controller][context];
require(lock.isLocked(), LockRequired());
Timestamp start = Timestamps.currentTime();
require(lock.expiry != Timestamp.wrap(0), LockRequired());
require(start < lock.expiry, LockExpired());
uint64 duration = Timestamp.unwrap(lock.maximum) - Timestamp.unwrap(start);
int256 total = int256(uint256(duration)) * TokensPerSecond.unwrap(rate);
Balance memory balance = _getBalance(controller, context, from);
require(total <= int256(balance.available), InsufficientBalance());
_flows[controller][context][to] = Flow({start: start, rate: rate});
_flows[controller][context][from] = Flow({start: start, rate: -rate});
Flow memory senderFlow = _flows[controller][context][from];
senderFlow.start = start;
senderFlow.rate = senderFlow.rate - rate;
Flow memory receiverFlow = _flows[controller][context][to];
receiverFlow.start = start;
receiverFlow.rate = receiverFlow.rate + rate;
Balance memory senderBalance = _getBalance(controller, context, from);
uint128 flowMaximum = uint128(-senderFlow._totalAt(lock.maximum));
require(flowMaximum <= senderBalance.available, InsufficientBalance());
_flows[controller][context][from] = senderFlow;
_flows[controller][context][to] = receiverFlow;
}
error InsufficientBalance();
@ -199,6 +200,6 @@ abstract contract VaultBase {
error AlreadyLocked();
error ExpiryPastMaximum();
error InvalidExpiry();
error LockExpired();
error LockRequired();
error NegativeFlow();
}

View File

@ -5,6 +5,7 @@ const {
currentTime,
advanceTimeTo,
mine,
setAutomine,
snapshot,
revert,
} = require("./evm")
@ -380,7 +381,7 @@ describe("Vault", function () {
await vault.lock(context, expiry, maximum)
await advanceTimeTo(expiry)
const extending = vault.extendLock(context, maximum)
await expect(extending).to.be.revertedWith("LockExpired")
await expect(extending).to.be.revertedWith("LockRequired")
})
it("allows locked tokens to be burned", async function () {
@ -404,12 +405,14 @@ describe("Vault", function () {
let sender
let receiver
let receiver2
beforeEach(async function () {
await token.connect(account).approve(vault.address, deposit)
await vault.deposit(context, account.address, deposit)
sender = account.address
receiver = account2.address
receiver2 = account3.address
})
async function getBalance(recipient) {
return await vault.getBalance(context, recipient)
@ -427,7 +430,7 @@ describe("Vault", function () {
await vault.lock(context, expiry, expiry)
await advanceTimeTo(expiry)
await expect(vault.flow(context, sender, receiver, 2)).to.be.revertedWith(
"LockExpired"
"LockRequired"
)
})
@ -452,6 +455,22 @@ describe("Vault", function () {
expect(await getBalance(receiver)).to.equal(8)
})
it("can move tokens to several different recipients", async function () {
await setAutomine(false)
await vault.flow(context, sender, receiver, 1)
await vault.flow(context, sender, receiver2, 2)
await mine()
const start = await currentTime()
await advanceTimeTo(start + 2)
expect(await getBalance(sender)).to.equal(deposit - 6)
expect(await getBalance(receiver)).to.equal(2)
expect(await getBalance(receiver2)).to.equal(4)
await advanceTimeTo(start + 4)
expect(await getBalance(sender)).to.equal(deposit - 12)
expect(await getBalance(receiver)).to.equal(4)
expect(await getBalance(receiver2)).to.equal(8)
})
it("designates tokens that flow for the recipient", async function () {
await vault.flow(context, sender, receiver, 3)
const start = await currentTime()
@ -491,6 +510,24 @@ describe("Vault", function () {
vault.flow(context, sender, receiver, rate + 1)
).to.be.revertedWith("InsufficientBalance")
})
it("rejects total flows exceeding available tokens", async function () {
const duration = maximum - (await currentTime())
const rate = Math.round(((2 / 3) * deposit) / duration)
await vault.flow(context, sender, receiver, rate)
await expect(
vault.flow(context, sender, receiver, rate)
).to.be.revertedWith("InsufficientBalance")
})
it("cannot flow designated tokens", async function () {
await vault.designate(context, sender, 10)
const duration = maximum - (await currentTime())
const rate = Math.round(deposit / duration)
await expect(
vault.flow(context, sender, receiver, rate)
).to.be.revertedWith("InsufficientBalance")
})
})
})
})