diff --git a/contracts/vault/Locks.sol b/contracts/vault/Locks.sol index 276a3f1..b0742ab 100644 --- a/contracts/vault/Locks.sol +++ b/contracts/vault/Locks.sol @@ -6,6 +6,7 @@ import "./Timestamps.sol"; struct Lock { Timestamp expiry; Timestamp maximum; + uint128 value; } library Locks { diff --git a/contracts/vault/VaultBase.sol b/contracts/vault/VaultBase.sol index c79dd7c..313059e 100644 --- a/contracts/vault/VaultBase.sol +++ b/contracts/vault/VaultBase.sol @@ -81,6 +81,7 @@ abstract contract VaultBase { ) internal { Recipient recipient = Recipient.wrap(from); _balances[controller][context][recipient].available += amount; + _locks[controller][context].value += amount; _token.safeTransferFrom(from, address(this), amount); } @@ -98,10 +99,20 @@ abstract contract VaultBase { Context context, Recipient recipient ) internal { - require(!_locks[controller][context].isLocked(), Locked()); - delete _locks[controller][context]; + Lock memory lock = _locks[controller][context]; + require(!lock.isLocked(), Locked()); + Balance memory balance = _getBalance(controller, context, recipient); uint128 amount = balance.available + balance.designated; + + lock.value -= amount; + + if (lock.value == 0) { + delete _locks[controller][context]; + } else { + _locks[controller][context] = lock; + } + _delete(controller, context, recipient); _token.safeTransfer(Recipient.unwrap(recipient), amount); } @@ -111,12 +122,22 @@ abstract contract VaultBase { Context context, Recipient recipient ) internal { + Lock memory lock = _locks[controller][context]; + Flow memory flow = _flows[controller][context][recipient]; require(flow.rate == TokensPerSecond.wrap(0), CannotBurnFlowingTokens()); Balance memory balance = _getBalance(controller, context, recipient); uint128 amount = balance.available + balance.designated; + lock.value -= amount; + + if (lock.value == 0) { + delete _locks[controller][context]; + } else { + _locks[controller][context] = lock; + } + _delete(controller, context, recipient); _token.safeTransfer(address(0xdead), amount); diff --git a/test/Vault.tests.js b/test/Vault.tests.js index 87af047..faf65f1 100644 --- a/test/Vault.tests.js +++ b/test/Vault.tests.js @@ -391,10 +391,26 @@ describe("Vault", function () { expect(await vault.getBalance(context, account.address)).to.equal(0) }) - it("deletes lock when funds are withdrawn", async function () { + it("deletes lock when all tokens are withdrawn/burned", async function () { await vault.lock(context, expiry, expiry) + await vault.transfer(context, account.address, account2.address, 20) + await vault.transfer(context, account2.address, account3.address, 10) + + // part of the tokens are burned + await vault.burn(context, account2.address) await advanceTimeTo(expiry) + expect((await vault.getLock(context))[0]).not.to.equal(0) + expect((await vault.getLock(context))[1]).not.to.equal(0) + + // part of the tokens are withdrawn await vault.withdraw(context, account.address) + expect((await vault.getLock(context))[0]).not.to.equal(0) + expect((await vault.getLock(context))[1]).not.to.equal(0) + + // remainder of the tokens are withdrawn by recipient + await vault + .connect(account3) + .withdrawByRecipient(controller.address, context) expect((await vault.getLock(context))[0]).to.equal(0) expect((await vault.getLock(context))[1]).to.equal(0) })