From b62c72b5e1a348c2d47c9f3c31c1d712e37286b7 Mon Sep 17 00:00:00 2001 From: Eric <5089238+emizzle@users.noreply.github.com> Date: Wed, 18 Sep 2024 20:22:17 +1000 Subject: [PATCH] uses msg.sender for slot reservation address --- contracts/SlotReservations.sol | 13 +++-- contracts/TestSlotReservations.sol | 16 ++++++ test/SlotReservations.test.js | 80 ++++++++++++++++++++---------- test/examples.js | 3 -- 4 files changed, 77 insertions(+), 35 deletions(-) create mode 100644 contracts/TestSlotReservations.sol diff --git a/contracts/SlotReservations.sol b/contracts/SlotReservations.sol index d1bfd12..5117186 100644 --- a/contracts/SlotReservations.sol +++ b/contracts/SlotReservations.sol @@ -7,20 +7,19 @@ import "./Requests.sol"; contract SlotReservations { using EnumerableSet for EnumerableSet.AddressSet; - mapping(SlotId => EnumerableSet.AddressSet) private _reservations; + mapping(SlotId => EnumerableSet.AddressSet) internal _reservations; uint8 private constant _MAX_RESERVATIONS = 3; - function reserveSlot(SlotId slotId, address host) public returns (bool) { - require(canReserveSlot(slotId, host), "Reservation not allowed"); + function reserveSlot(SlotId slotId) public returns (bool) { + address host = msg.sender; + require(canReserveSlot(slotId), "Reservation not allowed"); // returns false if set already contains address return _reservations[slotId].add(host); } - function canReserveSlot( - SlotId slotId, - address host - ) public view returns (bool) { + function canReserveSlot(SlotId slotId) public view returns (bool) { + address host = msg.sender; return // TODO: add in check for address inside of expanding window (_reservations[slotId].length() < _MAX_RESERVATIONS) && diff --git a/contracts/TestSlotReservations.sol b/contracts/TestSlotReservations.sol new file mode 100644 index 0000000..edc751e --- /dev/null +++ b/contracts/TestSlotReservations.sol @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.23; + +import "./SlotReservations.sol"; + +contract TestSlotReservations is SlotReservations { + using EnumerableSet for EnumerableSet.AddressSet; + + function contains(SlotId slotId, address host) public view returns (bool) { + return _reservations[slotId].contains(host); + } + + function length(SlotId slotId) public view returns (uint256) { + return _reservations[slotId].length(); + } +} diff --git a/test/SlotReservations.test.js b/test/SlotReservations.test.js index d82f9f1..c2042c3 100644 --- a/test/SlotReservations.test.js +++ b/test/SlotReservations.test.js @@ -1,6 +1,6 @@ const { expect } = require("chai") const { ethers } = require("hardhat") -const { exampleRequest, exampleAddress } = require("./examples") +const { exampleRequest } = require("./examples") const { requestId, slotId } = require("./ids") describe("SlotReservations", function () { @@ -10,16 +10,13 @@ describe("SlotReservations", function () { let slot beforeEach(async function () { - let SlotReservations = await ethers.getContractFactory("SlotReservations") + let SlotReservations = await ethers.getContractFactory( + "TestSlotReservations" + ) reservations = await SlotReservations.deploy() - - provider = exampleAddress() - address1 = exampleAddress() - address2 = exampleAddress() - address3 = exampleAddress() + ;[provider, address1, address2, address3] = await ethers.getSigners() request = await exampleRequest() - request.client = exampleAddress() slot = { request: requestId(request), @@ -27,47 +24,80 @@ describe("SlotReservations", function () { } }) + function switchAccount(account) { + reservations = reservations.connect(account) + } + it("allows a slot to be reserved", async function () { - let reserved = await reservations.callStatic.reserveSlot( - slotId(slot), - provider - ) + let id = slotId(slot) + let reserved = await reservations.callStatic.reserveSlot(id) expect(reserved).to.be.true }) + it("contains the correct addresses after reservation", async function () { + let id = slotId(slot) + await reservations.reserveSlot(id) + expect(await reservations.contains(id, provider.address)).to.be.true + + switchAccount(address1) + await reservations.reserveSlot(id) + expect(await reservations.contains(id, address1.address)).to.be.true + }) + + it("has the correct number of addresses after reservation", async function () { + let id = slotId(slot) + await reservations.reserveSlot(id) + expect(await reservations.length(id)).to.equal(1) + + switchAccount(address1) + await reservations.reserveSlot(id) + expect(await reservations.length(id)).to.equal(2) + }) + it("reports a slot can be reserved", async function () { - expect(await reservations.canReserveSlot(slotId(slot), provider)).to.be.true + expect(await reservations.canReserveSlot(slotId(slot))).to.be.true }) it("cannot reserve a slot more than once", async function () { let id = slotId(slot) - await reservations.reserveSlot(id, provider) - await expect(reservations.reserveSlot(id, provider)).to.be.revertedWith( + await reservations.reserveSlot(id) + await expect(reservations.reserveSlot(id)).to.be.revertedWith( "Reservation not allowed" ) + expect(await reservations.length(id)).to.equal(1) }) it("reports a slot cannot be reserved if already reserved", async function () { let id = slotId(slot) - await reservations.reserveSlot(id, provider) - expect(await reservations.canReserveSlot(id, provider)).to.be.false + await reservations.reserveSlot(id) + expect(await reservations.canReserveSlot(id)).to.be.false }) it("cannot reserve a slot if reservations are at capacity", async function () { let id = slotId(slot) - await reservations.reserveSlot(id, address1) - await reservations.reserveSlot(id, address2) - await reservations.reserveSlot(id, address3) - await expect(reservations.reserveSlot(id, provider)).to.be.revertedWith( + switchAccount(address1) + await reservations.reserveSlot(id) + switchAccount(address2) + await reservations.reserveSlot(id) + switchAccount(address3) + await reservations.reserveSlot(id) + switchAccount(provider) + await expect(reservations.reserveSlot(id)).to.be.revertedWith( "Reservation not allowed" ) + expect(await reservations.length(id)).to.equal(3) + expect(await reservations.contains(id, provider.address)).to.be.false }) it("reports a slot cannot be reserved if reservations are at capacity", async function () { let id = slotId(slot) - await reservations.reserveSlot(id, address1) - await reservations.reserveSlot(id, address2) - await reservations.reserveSlot(id, address3) - expect(await reservations.canReserveSlot(id, provider)).to.be.false + switchAccount(address1) + await reservations.reserveSlot(id) + switchAccount(address2) + await reservations.reserveSlot(id) + switchAccount(address3) + await reservations.reserveSlot(id) + switchAccount(provider) + expect(await reservations.canReserveSlot(id)).to.be.false }) }) diff --git a/test/examples.js b/test/examples.js index 73900a3..06d8428 100644 --- a/test/examples.js +++ b/test/examples.js @@ -51,12 +51,9 @@ const invalidProof = () => ({ c: { x: 0, y: 0 }, }) -const exampleAddress = () => hexlify(randomBytes(20)) - module.exports = { exampleConfiguration, exampleRequest, exampleProof, invalidProof, - exampleAddress, }