uses msg.sender for slot reservation address

This commit is contained in:
Eric 2024-09-18 20:22:17 +10:00
parent 64ce222e24
commit b62c72b5e1
No known key found for this signature in database
4 changed files with 77 additions and 35 deletions

View File

@ -7,20 +7,19 @@ import "./Requests.sol";
contract SlotReservations {
using EnumerableSet for EnumerableSet.AddressSet;
mapping(SlotId => EnumerableSet.AddressSet) private _reservations;
mapping(SlotId => EnumerableSet.AddressSet) internal _reservations;
uint8 private constant _MAX_RESERVATIONS = 3;
function reserveSlot(SlotId slotId, address host) public returns (bool) {
require(canReserveSlot(slotId, host), "Reservation not allowed");
function reserveSlot(SlotId slotId) public returns (bool) {
address host = msg.sender;
require(canReserveSlot(slotId), "Reservation not allowed");
// returns false if set already contains address
return _reservations[slotId].add(host);
}
function canReserveSlot(
SlotId slotId,
address host
) public view returns (bool) {
function canReserveSlot(SlotId slotId) public view returns (bool) {
address host = msg.sender;
return
// TODO: add in check for address inside of expanding window
(_reservations[slotId].length() < _MAX_RESERVATIONS) &&

View File

@ -0,0 +1,16 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.23;
import "./SlotReservations.sol";
contract TestSlotReservations is SlotReservations {
using EnumerableSet for EnumerableSet.AddressSet;
function contains(SlotId slotId, address host) public view returns (bool) {
return _reservations[slotId].contains(host);
}
function length(SlotId slotId) public view returns (uint256) {
return _reservations[slotId].length();
}
}

View File

@ -1,6 +1,6 @@
const { expect } = require("chai")
const { ethers } = require("hardhat")
const { exampleRequest, exampleAddress } = require("./examples")
const { exampleRequest } = require("./examples")
const { requestId, slotId } = require("./ids")
describe("SlotReservations", function () {
@ -10,16 +10,13 @@ describe("SlotReservations", function () {
let slot
beforeEach(async function () {
let SlotReservations = await ethers.getContractFactory("SlotReservations")
let SlotReservations = await ethers.getContractFactory(
"TestSlotReservations"
)
reservations = await SlotReservations.deploy()
provider = exampleAddress()
address1 = exampleAddress()
address2 = exampleAddress()
address3 = exampleAddress()
;[provider, address1, address2, address3] = await ethers.getSigners()
request = await exampleRequest()
request.client = exampleAddress()
slot = {
request: requestId(request),
@ -27,47 +24,80 @@ describe("SlotReservations", function () {
}
})
function switchAccount(account) {
reservations = reservations.connect(account)
}
it("allows a slot to be reserved", async function () {
let reserved = await reservations.callStatic.reserveSlot(
slotId(slot),
provider
)
let id = slotId(slot)
let reserved = await reservations.callStatic.reserveSlot(id)
expect(reserved).to.be.true
})
it("contains the correct addresses after reservation", async function () {
let id = slotId(slot)
await reservations.reserveSlot(id)
expect(await reservations.contains(id, provider.address)).to.be.true
switchAccount(address1)
await reservations.reserveSlot(id)
expect(await reservations.contains(id, address1.address)).to.be.true
})
it("has the correct number of addresses after reservation", async function () {
let id = slotId(slot)
await reservations.reserveSlot(id)
expect(await reservations.length(id)).to.equal(1)
switchAccount(address1)
await reservations.reserveSlot(id)
expect(await reservations.length(id)).to.equal(2)
})
it("reports a slot can be reserved", async function () {
expect(await reservations.canReserveSlot(slotId(slot), provider)).to.be.true
expect(await reservations.canReserveSlot(slotId(slot))).to.be.true
})
it("cannot reserve a slot more than once", async function () {
let id = slotId(slot)
await reservations.reserveSlot(id, provider)
await expect(reservations.reserveSlot(id, provider)).to.be.revertedWith(
await reservations.reserveSlot(id)
await expect(reservations.reserveSlot(id)).to.be.revertedWith(
"Reservation not allowed"
)
expect(await reservations.length(id)).to.equal(1)
})
it("reports a slot cannot be reserved if already reserved", async function () {
let id = slotId(slot)
await reservations.reserveSlot(id, provider)
expect(await reservations.canReserveSlot(id, provider)).to.be.false
await reservations.reserveSlot(id)
expect(await reservations.canReserveSlot(id)).to.be.false
})
it("cannot reserve a slot if reservations are at capacity", async function () {
let id = slotId(slot)
await reservations.reserveSlot(id, address1)
await reservations.reserveSlot(id, address2)
await reservations.reserveSlot(id, address3)
await expect(reservations.reserveSlot(id, provider)).to.be.revertedWith(
switchAccount(address1)
await reservations.reserveSlot(id)
switchAccount(address2)
await reservations.reserveSlot(id)
switchAccount(address3)
await reservations.reserveSlot(id)
switchAccount(provider)
await expect(reservations.reserveSlot(id)).to.be.revertedWith(
"Reservation not allowed"
)
expect(await reservations.length(id)).to.equal(3)
expect(await reservations.contains(id, provider.address)).to.be.false
})
it("reports a slot cannot be reserved if reservations are at capacity", async function () {
let id = slotId(slot)
await reservations.reserveSlot(id, address1)
await reservations.reserveSlot(id, address2)
await reservations.reserveSlot(id, address3)
expect(await reservations.canReserveSlot(id, provider)).to.be.false
switchAccount(address1)
await reservations.reserveSlot(id)
switchAccount(address2)
await reservations.reserveSlot(id)
switchAccount(address3)
await reservations.reserveSlot(id)
switchAccount(provider)
expect(await reservations.canReserveSlot(id)).to.be.false
})
})

View File

@ -51,12 +51,9 @@ const invalidProof = () => ({
c: { x: 0, y: 0 },
})
const exampleAddress = () => hexlify(randomBytes(20))
module.exports = {
exampleConfiguration,
exampleRequest,
exampleProof,
invalidProof,
exampleAddress,
}