diff --git a/contracts/Timestamps.sol b/contracts/Timestamps.sol index f751e05..6e2dcda 100644 --- a/contracts/Timestamps.sol +++ b/contracts/Timestamps.sol @@ -3,19 +3,24 @@ pragma solidity 0.8.28; type Timestamp is uint64; -using {_notEquals as !=} for Timestamp global; -using {_lessThan as <} for Timestamp global; -using {_atMost as <=} for Timestamp global; +using {_timestampEquals as ==} for Timestamp global; +using {_timestampNotEqual as !=} for Timestamp global; +using {_timestampLessThan as <} for Timestamp global; +using {_timestampAtMost as <=} for Timestamp global; -function _notEquals(Timestamp a, Timestamp b) pure returns (bool) { +function _timestampEquals(Timestamp a, Timestamp b) pure returns (bool) { + return Timestamp.unwrap(a) == Timestamp.unwrap(b); +} + +function _timestampNotEqual(Timestamp a, Timestamp b) pure returns (bool) { return Timestamp.unwrap(a) != Timestamp.unwrap(b); } -function _lessThan(Timestamp a, Timestamp b) pure returns (bool) { +function _timestampLessThan(Timestamp a, Timestamp b) pure returns (bool) { return Timestamp.unwrap(a) < Timestamp.unwrap(b); } -function _atMost(Timestamp a, Timestamp b) pure returns (bool) { +function _timestampAtMost(Timestamp a, Timestamp b) pure returns (bool) { return Timestamp.unwrap(a) <= Timestamp.unwrap(b); } diff --git a/contracts/TokensPerSecond.sol b/contracts/TokensPerSecond.sol index 677da3a..455a66e 100644 --- a/contracts/TokensPerSecond.sol +++ b/contracts/TokensPerSecond.sol @@ -5,13 +5,18 @@ import "./Timestamps.sol"; type TokensPerSecond is int256; -using {_negate as -} for TokensPerSecond global; -using {_equals as ==} for TokensPerSecond global; +using {_tokensPerSecondNegate as -} for TokensPerSecond global; +using {_tokensPerSecondEquals as ==} for TokensPerSecond global; -function _negate(TokensPerSecond rate) pure returns (TokensPerSecond) { +function _tokensPerSecondNegate( + TokensPerSecond rate +) pure returns (TokensPerSecond) { return TokensPerSecond.wrap(-TokensPerSecond.unwrap(rate)); } -function _equals(TokensPerSecond a, TokensPerSecond b) pure returns (bool) { +function _tokensPerSecondEquals( + TokensPerSecond a, + TokensPerSecond b +) pure returns (bool) { return TokensPerSecond.unwrap(a) == TokensPerSecond.unwrap(b); } diff --git a/contracts/VaultBase.sol b/contracts/VaultBase.sol index d3b8fad..5aee717 100644 --- a/contracts/VaultBase.sol +++ b/contracts/VaultBase.sol @@ -157,10 +157,8 @@ abstract contract VaultBase { Timestamp expiry, Timestamp maximum ) internal { - require( - Timestamp.unwrap(_getLock(controller, context).maximum) == 0, - AlreadyLocked() - ); + Lock memory existing = _getLock(controller, context); + require(existing.maximum == Timestamp.wrap(0), AlreadyLocked()); require(expiry <= maximum, ExpiryPastMaximum()); _locks[controller][context] = Lock({expiry: expiry, maximum: maximum}); } @@ -170,10 +168,10 @@ abstract contract VaultBase { Context context, Timestamp expiry ) internal { - Lock memory previous = _getLock(controller, context); - require(Timestamps.currentTime() < previous.expiry, LockExpired()); - require(previous.expiry <= expiry, InvalidExpiry()); - require(expiry <= previous.maximum, ExpiryPastMaximum()); + Lock memory existing = _getLock(controller, context); + require(Timestamps.currentTime() < existing.expiry, LockExpired()); + require(existing.expiry <= expiry, InvalidExpiry()); + require(expiry <= existing.maximum, ExpiryPastMaximum()); _locks[controller][context].expiry = expiry; } @@ -185,8 +183,9 @@ abstract contract VaultBase { TokensPerSecond rate ) internal { Lock memory lock = _getLock(controller, context); - require(lock.expiry != Timestamp.wrap(0), LockRequired()); Timestamp start = Timestamps.currentTime(); + require(lock.expiry != Timestamp.wrap(0), LockRequired()); + require(start < lock.expiry, LockExpired()); uint64 duration = Timestamp.unwrap(lock.maximum) - Timestamp.unwrap(start); int256 total = int256(uint256(duration)) * TokensPerSecond.unwrap(rate); Balance memory balance = _getBalance(controller, context, from); diff --git a/test/Vault.tests.js b/test/Vault.tests.js index 74da3c7..ca58a6d 100644 --- a/test/Vault.tests.js +++ b/test/Vault.tests.js @@ -406,6 +406,15 @@ describe("Vault", function () { ) }) + it("requires that the lock is not expired", async function () { + let expiry = (await currentTime()) + 20 + await vault.lockup(context, expiry, expiry) + await advanceTimeTo(expiry) + await expect(vault.flow(context, sender, receiver, 2)).to.be.revertedWith( + "LockExpired" + ) + }) + describe("when a lock is set", async function () { let expiry let maximum