diff --git a/contracts/TokensPerSecond.sol b/contracts/TokensPerSecond.sol deleted file mode 100644 index 455a66e..0000000 --- a/contracts/TokensPerSecond.sol +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity 0.8.28; - -import "./Timestamps.sol"; - -type TokensPerSecond is int256; - -using {_tokensPerSecondNegate as -} for TokensPerSecond global; -using {_tokensPerSecondEquals as ==} for TokensPerSecond global; - -function _tokensPerSecondNegate( - TokensPerSecond rate -) pure returns (TokensPerSecond) { - return TokensPerSecond.wrap(-TokensPerSecond.unwrap(rate)); -} - -function _tokensPerSecondEquals( - TokensPerSecond a, - TokensPerSecond b -) pure returns (bool) { - return TokensPerSecond.unwrap(a) == TokensPerSecond.unwrap(b); -} diff --git a/contracts/Vault.sol b/contracts/Vault.sol index 7acc5fb..1bb7e8b 100644 --- a/contracts/Vault.sol +++ b/contracts/Vault.sol @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MIT pragma solidity 0.8.28; -import "./VaultBase.sol"; +import "./vault/VaultBase.sol"; contract Vault is VaultBase { // solhint-disable-next-line no-empty-blocks @@ -10,7 +10,7 @@ contract Vault is VaultBase { function getBalance( Context context, Recipient recipient - ) public view returns (uint256) { + ) public view returns (uint128) { Controller controller = Controller.wrap(msg.sender); Balance memory b = _getBalance(controller, context, recipient); return b.available + b.designated; @@ -19,7 +19,7 @@ contract Vault is VaultBase { function getDesignatedBalance( Context context, Recipient recipient - ) public view returns (uint256) { + ) public view returns (uint128) { Controller controller = Controller.wrap(msg.sender); Balance memory b = _getBalance(controller, context, recipient); return b.designated; @@ -30,7 +30,7 @@ contract Vault is VaultBase { return _getLock(controller, context); } - function deposit(Context context, address from, uint256 amount) public { + function deposit(Context context, address from, uint128 amount) public { Controller controller = Controller.wrap(msg.sender); _deposit(controller, context, from, amount); } @@ -54,7 +54,7 @@ contract Vault is VaultBase { Context context, Recipient from, Recipient to, - uint256 amount + uint128 amount ) public { Controller controller = Controller.wrap(msg.sender); _transfer(controller, context, from, to, amount); @@ -63,7 +63,7 @@ contract Vault is VaultBase { function designate( Context context, Recipient recipient, - uint256 amount + uint128 amount ) public { Controller controller = Controller.wrap(msg.sender); _designate(controller, context, recipient, amount); diff --git a/contracts/vault/Flows.sol b/contracts/vault/Flows.sol new file mode 100644 index 0000000..4363f40 --- /dev/null +++ b/contracts/vault/Flows.sol @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.28; + +import "./Timestamps.sol"; +import "./TokensPerSecond.sol"; + +struct Flow { + Timestamp start; + TokensPerSecond rate; +} + +library Flows { + function _totalAt( + Flow memory flow, + Timestamp timestamp + ) internal pure returns (int128) { + int128 rate = TokensPerSecond.unwrap(flow.rate); + Timestamp start = flow.start; + Timestamp end = timestamp; + uint64 duration = Timestamp.unwrap(end) - Timestamp.unwrap(start); + return rate * int128(uint128(duration)); + } +} diff --git a/contracts/vault/Locks.sol b/contracts/vault/Locks.sol new file mode 100644 index 0000000..276a3f1 --- /dev/null +++ b/contracts/vault/Locks.sol @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.28; + +import "./Timestamps.sol"; + +struct Lock { + Timestamp expiry; + Timestamp maximum; +} + +library Locks { + function isLocked(Lock memory lock) internal view returns (bool) { + return Timestamps.currentTime() < lock.expiry; + } +} diff --git a/contracts/Timestamps.sol b/contracts/vault/Timestamps.sol similarity index 100% rename from contracts/Timestamps.sol rename to contracts/vault/Timestamps.sol diff --git a/contracts/vault/TokensPerSecond.sol b/contracts/vault/TokensPerSecond.sol new file mode 100644 index 0000000..174be78 --- /dev/null +++ b/contracts/vault/TokensPerSecond.sol @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.28; + +import "./Timestamps.sol"; + +type TokensPerSecond is int128; + +using {_tokensPerSecondNegate as -} for TokensPerSecond global; +using {_tokensPerSecondMinus as -} for TokensPerSecond global; +using {_tokensPerSecondPlus as +} for TokensPerSecond global; +using {_tokensPerSecondEquals as ==} for TokensPerSecond global; +using {_tokensPerSecondNotEqual as !=} for TokensPerSecond global; +using {_tokensPerSecondAtLeast as >=} for TokensPerSecond global; + +function _tokensPerSecondNegate( + TokensPerSecond rate +) pure returns (TokensPerSecond) { + return TokensPerSecond.wrap(-TokensPerSecond.unwrap(rate)); +} + +function _tokensPerSecondMinus( + TokensPerSecond a, + TokensPerSecond b +) pure returns (TokensPerSecond) { + return + TokensPerSecond.wrap(TokensPerSecond.unwrap(a) - TokensPerSecond.unwrap(b)); +} + +function _tokensPerSecondPlus( + TokensPerSecond a, + TokensPerSecond b +) pure returns (TokensPerSecond) { + return + TokensPerSecond.wrap(TokensPerSecond.unwrap(a) + TokensPerSecond.unwrap(b)); +} + +function _tokensPerSecondEquals( + TokensPerSecond a, + TokensPerSecond b +) pure returns (bool) { + return TokensPerSecond.unwrap(a) == TokensPerSecond.unwrap(b); +} + +function _tokensPerSecondNotEqual( + TokensPerSecond a, + TokensPerSecond b +) pure returns (bool) { + return TokensPerSecond.unwrap(a) != TokensPerSecond.unwrap(b); +} + +function _tokensPerSecondAtLeast( + TokensPerSecond a, + TokensPerSecond b +) pure returns (bool) { + return TokensPerSecond.unwrap(a) >= TokensPerSecond.unwrap(b); +} diff --git a/contracts/VaultBase.sol b/contracts/vault/VaultBase.sol similarity index 61% rename from contracts/VaultBase.sol rename to contracts/vault/VaultBase.sol index 5aee717..b788c6c 100644 --- a/contracts/VaultBase.sol +++ b/contracts/vault/VaultBase.sol @@ -5,9 +5,13 @@ import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; import "./Timestamps.sol"; import "./TokensPerSecond.sol"; +import "./Flows.sol"; +import "./Locks.sol"; using SafeERC20 for IERC20; using Timestamps for Timestamp; +using Flows for Flow; +using Locks for Lock; abstract contract VaultBase { IERC20 internal immutable _token; @@ -17,18 +21,8 @@ abstract contract VaultBase { type Recipient is address; struct Balance { - uint256 available; - uint256 designated; - } - - struct Lock { - Timestamp expiry; - Timestamp maximum; - } - - struct Flow { - Timestamp start; - TokensPerSecond rate; + uint128 available; + uint128 designated; } mapping(Controller => mapping(Context => Lock)) private _locks; @@ -46,29 +40,30 @@ abstract contract VaultBase { Context context, Recipient recipient ) internal view returns (Balance memory) { - Balance memory balance = _balances[controller][context][recipient]; - int256 accumulated = _accumulateFlow(controller, context, recipient); - if (accumulated >= 0) { - balance.designated += uint256(accumulated); - } else { - balance.available -= uint256(-accumulated); - } - return balance; + Balance storage balance = _balances[controller][context][recipient]; + Flow storage flow = _flows[controller][context][recipient]; + Lock storage lock = _locks[controller][context]; + Timestamp timestamp = Timestamps.currentTime(); + return _getBalanceAt(balance, flow, lock, timestamp); } - function _accumulateFlow( - Controller controller, - Context context, - Recipient recipient - ) private view returns (int256) { - Flow memory flow = _flows[controller][context][recipient]; - if (flow.rate == TokensPerSecond.wrap(0)) { - return 0; + function _getBalanceAt( + Balance memory balance, + Flow memory flow, + Lock storage lock, + Timestamp timestamp + ) private view returns (Balance memory) { + Balance memory result = balance; + if (flow.rate != TokensPerSecond.wrap(0)) { + Timestamp end = Timestamps.earliest(timestamp, lock.expiry); + int128 accumulated = flow._totalAt(end); + if (accumulated >= 0) { + result.designated += uint128(accumulated); + } else { + result.available -= uint128(-accumulated); + } } - Timestamp expiry = _getLock(controller, context).expiry; - Timestamp flowEnd = Timestamps.earliest(Timestamps.currentTime(), expiry); - uint64 duration = Timestamp.unwrap(flowEnd) - Timestamp.unwrap(flow.start); - return TokensPerSecond.unwrap(flow.rate) * int256(uint256(duration)); + return result; } function _getLock( @@ -82,7 +77,7 @@ abstract contract VaultBase { Controller controller, Context context, address from, - uint256 amount + uint128 amount ) internal { Recipient recipient = Recipient.wrap(from); _balances[controller][context][recipient].available += amount; @@ -102,13 +97,10 @@ abstract contract VaultBase { Context context, Recipient recipient ) internal { - require( - _getLock(controller, context).expiry <= Timestamps.currentTime(), - Locked() - ); + require(!_locks[controller][context].isLocked(), Locked()); delete _locks[controller][context]; Balance memory balance = _getBalance(controller, context, recipient); - uint256 amount = balance.available + balance.designated; + uint128 amount = balance.available + balance.designated; _delete(controller, context, recipient); _token.safeTransfer(Recipient.unwrap(recipient), amount); } @@ -119,7 +111,7 @@ abstract contract VaultBase { Recipient recipient ) internal { Balance memory balance = _getBalance(controller, context, recipient); - uint256 amount = balance.available + balance.designated; + uint128 amount = balance.available + balance.designated; _delete(controller, context, recipient); _token.safeTransfer(address(0xdead), amount); } @@ -129,7 +121,7 @@ abstract contract VaultBase { Context context, Recipient from, Recipient to, - uint256 amount + uint128 amount ) internal { require( amount <= _balances[controller][context][from].available, @@ -143,7 +135,7 @@ abstract contract VaultBase { Controller controller, Context context, Recipient recipient, - uint256 amount + uint128 amount ) internal { Balance storage balance = _balances[controller][context][recipient]; require(amount <= balance.available, InsufficientBalance()); @@ -157,9 +149,9 @@ abstract contract VaultBase { Timestamp expiry, Timestamp maximum ) internal { - Lock memory existing = _getLock(controller, context); - require(existing.maximum == Timestamp.wrap(0), AlreadyLocked()); require(expiry <= maximum, ExpiryPastMaximum()); + Lock memory existing = _locks[controller][context]; + require(existing.maximum == Timestamp.wrap(0), AlreadyLocked()); _locks[controller][context] = Lock({expiry: expiry, maximum: maximum}); } @@ -168,10 +160,10 @@ abstract contract VaultBase { Context context, Timestamp expiry ) internal { - Lock memory existing = _getLock(controller, context); - require(Timestamps.currentTime() < existing.expiry, LockExpired()); - require(existing.expiry <= expiry, InvalidExpiry()); - require(expiry <= existing.maximum, ExpiryPastMaximum()); + Lock memory lock = _locks[controller][context]; + require(lock.isLocked(), LockRequired()); + require(lock.expiry <= expiry, InvalidExpiry()); + require(expiry <= lock.maximum, ExpiryPastMaximum()); _locks[controller][context].expiry = expiry; } @@ -182,16 +174,25 @@ abstract contract VaultBase { Recipient to, TokensPerSecond rate ) internal { - Lock memory lock = _getLock(controller, context); + require(rate >= TokensPerSecond.wrap(0), NegativeFlow()); + + Lock memory lock = _locks[controller][context]; + require(lock.isLocked(), LockRequired()); + Timestamp start = Timestamps.currentTime(); - require(lock.expiry != Timestamp.wrap(0), LockRequired()); - require(start < lock.expiry, LockExpired()); - uint64 duration = Timestamp.unwrap(lock.maximum) - Timestamp.unwrap(start); - int256 total = int256(uint256(duration)) * TokensPerSecond.unwrap(rate); - Balance memory balance = _getBalance(controller, context, from); - require(total <= int256(balance.available), InsufficientBalance()); - _flows[controller][context][to] = Flow({start: start, rate: rate}); - _flows[controller][context][from] = Flow({start: start, rate: -rate}); + Flow memory senderFlow = _flows[controller][context][from]; + senderFlow.start = start; + senderFlow.rate = senderFlow.rate - rate; + Flow memory receiverFlow = _flows[controller][context][to]; + receiverFlow.start = start; + receiverFlow.rate = receiverFlow.rate + rate; + + Balance memory senderBalance = _getBalance(controller, context, from); + uint128 flowMaximum = uint128(-senderFlow._totalAt(lock.maximum)); + require(flowMaximum <= senderBalance.available, InsufficientBalance()); + + _flows[controller][context][from] = senderFlow; + _flows[controller][context][to] = receiverFlow; } error InsufficientBalance(); @@ -199,6 +200,6 @@ abstract contract VaultBase { error AlreadyLocked(); error ExpiryPastMaximum(); error InvalidExpiry(); - error LockExpired(); error LockRequired(); + error NegativeFlow(); } diff --git a/test/Vault.tests.js b/test/Vault.tests.js index c34ef41..44f702d 100644 --- a/test/Vault.tests.js +++ b/test/Vault.tests.js @@ -5,6 +5,7 @@ const { currentTime, advanceTimeTo, mine, + setAutomine, snapshot, revert, } = require("./evm") @@ -380,7 +381,7 @@ describe("Vault", function () { await vault.lock(context, expiry, maximum) await advanceTimeTo(expiry) const extending = vault.extendLock(context, maximum) - await expect(extending).to.be.revertedWith("LockExpired") + await expect(extending).to.be.revertedWith("LockRequired") }) it("allows locked tokens to be burned", async function () { @@ -404,12 +405,14 @@ describe("Vault", function () { let sender let receiver + let receiver2 beforeEach(async function () { await token.connect(account).approve(vault.address, deposit) await vault.deposit(context, account.address, deposit) sender = account.address receiver = account2.address + receiver2 = account3.address }) async function getBalance(recipient) { return await vault.getBalance(context, recipient) @@ -427,7 +430,7 @@ describe("Vault", function () { await vault.lock(context, expiry, expiry) await advanceTimeTo(expiry) await expect(vault.flow(context, sender, receiver, 2)).to.be.revertedWith( - "LockExpired" + "LockRequired" ) }) @@ -452,6 +455,22 @@ describe("Vault", function () { expect(await getBalance(receiver)).to.equal(8) }) + it("can move tokens to several different recipients", async function () { + await setAutomine(false) + await vault.flow(context, sender, receiver, 1) + await vault.flow(context, sender, receiver2, 2) + await mine() + const start = await currentTime() + await advanceTimeTo(start + 2) + expect(await getBalance(sender)).to.equal(deposit - 6) + expect(await getBalance(receiver)).to.equal(2) + expect(await getBalance(receiver2)).to.equal(4) + await advanceTimeTo(start + 4) + expect(await getBalance(sender)).to.equal(deposit - 12) + expect(await getBalance(receiver)).to.equal(4) + expect(await getBalance(receiver2)).to.equal(8) + }) + it("designates tokens that flow for the recipient", async function () { await vault.flow(context, sender, receiver, 3) const start = await currentTime() @@ -491,6 +510,24 @@ describe("Vault", function () { vault.flow(context, sender, receiver, rate + 1) ).to.be.revertedWith("InsufficientBalance") }) + + it("rejects total flows exceeding available tokens", async function () { + const duration = maximum - (await currentTime()) + const rate = Math.round(((2 / 3) * deposit) / duration) + await vault.flow(context, sender, receiver, rate) + await expect( + vault.flow(context, sender, receiver, rate) + ).to.be.revertedWith("InsufficientBalance") + }) + + it("cannot flow designated tokens", async function () { + await vault.designate(context, sender, 10) + const duration = maximum - (await currentTime()) + const rate = Math.round(deposit / duration) + await expect( + vault.flow(context, sender, receiver, rate) + ).to.be.revertedWith("InsufficientBalance") + }) }) }) })