diff --git a/contracts/Timestamps.sol b/contracts/Timestamps.sol index 96756b7..f751e05 100644 --- a/contracts/Timestamps.sol +++ b/contracts/Timestamps.sol @@ -3,9 +3,14 @@ pragma solidity 0.8.28; type Timestamp is uint64; +using {_notEquals as !=} for Timestamp global; using {_lessThan as <} for Timestamp global; using {_atMost as <=} for Timestamp global; +function _notEquals(Timestamp a, Timestamp b) pure returns (bool) { + return Timestamp.unwrap(a) != Timestamp.unwrap(b); +} + function _lessThan(Timestamp a, Timestamp b) pure returns (bool) { return Timestamp.unwrap(a) < Timestamp.unwrap(b); } @@ -18,4 +23,15 @@ library Timestamps { function currentTime() internal view returns (Timestamp) { return Timestamp.wrap(uint64(block.timestamp)); } + + function earliest( + Timestamp a, + Timestamp b + ) internal pure returns (Timestamp) { + if (a <= b) { + return a; + } else { + return b; + } + } } diff --git a/contracts/VaultBase.sol b/contracts/VaultBase.sol index 3789b2e..2b31273 100644 --- a/contracts/VaultBase.sol +++ b/contracts/VaultBase.sol @@ -47,8 +47,7 @@ abstract contract VaultBase { Recipient recipient ) internal view returns (Balance memory) { Balance memory balance = _balances[controller][context][recipient]; - Flow memory flow = _flows[controller][context][recipient]; - int256 accumulated = _accumulate(flow, Timestamps.currentTime()); + int256 accumulated = _accumulateFlow(controller, context, recipient); if (accumulated >= 0) { balance.designated += uint256(accumulated); } else { @@ -57,14 +56,18 @@ abstract contract VaultBase { return balance; } - function _accumulate( - Flow memory flow, - Timestamp end - ) private pure returns (int256) { - if (TokensPerSecond.unwrap(flow.rate) == 0) { + function _accumulateFlow( + Controller controller, + Context context, + Recipient recipient + ) private view returns (int256) { + Flow memory flow = _flows[controller][context][recipient]; + if (flow.rate == TokensPerSecond.wrap(0)) { return 0; } - uint64 duration = Timestamp.unwrap(end) - Timestamp.unwrap(flow.start); + Timestamp expiry = _getLock(controller, context).expiry; + Timestamp flowEnd = Timestamps.earliest(Timestamps.currentTime(), expiry); + uint64 duration = Timestamp.unwrap(flowEnd) - Timestamp.unwrap(flow.start); return TokensPerSecond.unwrap(flow.rate) * int256(uint256(duration)); } @@ -99,7 +102,10 @@ abstract contract VaultBase { Context context, Recipient recipient ) internal { - require(_getLock(controller, context).expiry <= Timestamps.currentTime(), Locked()); + require( + _getLock(controller, context).expiry <= Timestamps.currentTime(), + Locked() + ); delete _locks[controller][context]; Balance memory balance = _getBalance(controller, context, recipient); uint256 amount = balance.available + balance.designated; @@ -178,6 +184,10 @@ abstract contract VaultBase { Recipient to, TokensPerSecond rate ) internal { + require( + _getLock(controller, context).expiry != Timestamp.wrap(0), + LockRequired() + ); Timestamp start = Timestamps.currentTime(); _flows[controller][context][to] = Flow({start: start, rate: rate}); _flows[controller][context][from] = Flow({start: start, rate: -rate}); @@ -189,4 +199,5 @@ abstract contract VaultBase { error ExpiryPastMaximum(); error InvalidExpiry(); error LockExpired(); + error LockRequired(); } diff --git a/test/Vault.tests.js b/test/Vault.tests.js index af28571..9ed3fa0 100644 --- a/test/Vault.tests.js +++ b/test/Vault.tests.js @@ -385,9 +385,14 @@ describe("Vault", function () { const context = randomBytes(32) const amount = 42 + let sender + let receiver + beforeEach(async function () { await token.connect(account).approve(vault.address, amount) await vault.deposit(context, account.address, amount) + sender = account.address + receiver = account2.address }) async function advanceTimeTo(timestamp) { @@ -395,22 +400,49 @@ describe("Vault", function () { await mine() } - it("moves tokens over time", async function () { - await vault.flow(context, account.address, account2.address, 2) - const start = await currentTime() - await advanceTimeTo(start + 2) - expect(await vault.balance(context, account.address)).to.equal(amount - 4) - expect(await vault.balance(context, account2.address)).to.equal(4) - await advanceTimeTo(start + 4) - expect(await vault.balance(context, account.address)).to.equal(amount - 8) - expect(await vault.balance(context, account2.address)).to.equal(8) + it("requires that a lock is set", async function () { + await expect(vault.flow(context, sender, receiver, 2)).to.be.revertedWith( + "LockRequired" + ) }) - it("designates tokens that flow for the recipient", async function () { - await vault.flow(context, account.address, account2.address, 3) - const start = await currentTime() - await advanceTimeTo(start + 7) - expect(await vault.designated(context, account2.address)).to.equal(21) + describe("when a lock is set", async function () { + let expiry + + beforeEach(async function () { + expiry = (await currentTime()) + 20 + await vault.lockup(context, expiry, expiry) + }) + + it("moves tokens over time", async function () { + await vault.flow(context, sender, receiver, 2) + const start = await currentTime() + await advanceTimeTo(start + 2) + expect(await vault.balance(context, sender)).to.equal(amount - 4) + expect(await vault.balance(context, receiver)).to.equal(4) + await advanceTimeTo(start + 4) + expect(await vault.balance(context, sender)).to.equal(amount - 8) + expect(await vault.balance(context, receiver)).to.equal(8) + }) + + it("designates tokens that flow for the recipient", async function () { + await vault.flow(context, sender, receiver, 3) + const start = await currentTime() + await advanceTimeTo(start + 7) + expect(await vault.designated(context, receiver)).to.equal(21) + }) + + it("stops flowing when lock expires", async function () { + await vault.flow(context, sender, receiver, 2) + const start = await currentTime() + await advanceTimeTo(expiry) + const total = (expiry - start) * 2 + expect(await vault.balance(context, sender)).to.equal(amount - total) + expect(await vault.balance(context, receiver)).to.equal(total) + await advanceTimeTo(expiry + 10) + expect(await vault.balance(context, sender)).to.equal(amount - total) + expect(await vault.balance(context, receiver)).to.equal(total) + }) }) }) })