314 lines
7.5 KiB
C++
314 lines
7.5 KiB
C++
|
// Copyright 2004-present Facebook. All Rights Reserved.
|
||
|
|
||
|
#include "CxxMessageQueue.h"
|
||
|
|
||
|
#include <folly/AtomicIntrusiveLinkedList.h>
|
||
|
|
||
|
#include <unordered_map>
|
||
|
#include <mutex>
|
||
|
#include <queue>
|
||
|
|
||
|
#include <glog/logging.h>
|
||
|
|
||
|
namespace facebook {
|
||
|
namespace react {
|
||
|
|
||
|
using detail::BinarySemaphore;
|
||
|
using detail::EventFlag;
|
||
|
|
||
|
using clock = std::chrono::steady_clock;
|
||
|
using time_point = clock::time_point;
|
||
|
static_assert(std::is_same<time_point, EventFlag::time_point>::value, "");
|
||
|
|
||
|
namespace {
|
||
|
// Returns the current instant as reported by the file-wide `clock` alias.
time_point now() {
  return clock::now();
}
|
||
|
|
||
|
class Task {
|
||
|
public:
|
||
|
static Task* create(std::function<void()>&& func) {
|
||
|
return new Task{std::move(func), false, time_point()};
|
||
|
}
|
||
|
|
||
|
static Task* createSync(std::function<void()>&& func) {
|
||
|
return new Task{std::move(func), true, time_point()};
|
||
|
}
|
||
|
|
||
|
static Task* createDelayed(std::function<void()>&& func, time_point startTime) {
|
||
|
return new Task{std::move(func), false, startTime};
|
||
|
}
|
||
|
|
||
|
std::function<void()> func;
|
||
|
// This flag is just to mark that the task is expected to be synchronous. If
|
||
|
// a synchronous task races with stopping the queue, the thread waiting on
|
||
|
// the synchronous task might never resume. We use this flag to detect this
|
||
|
// case and throw an error.
|
||
|
bool sync;
|
||
|
time_point startTime;
|
||
|
|
||
|
folly::AtomicIntrusiveLinkedListHook<Task> hook;
|
||
|
|
||
|
// Should this sort consider id also?
|
||
|
struct Compare {
|
||
|
bool operator()(const Task* a, const Task* b) {
|
||
|
return a->startTime > b->startTime;
|
||
|
}
|
||
|
};
|
||
|
};
|
||
|
|
||
|
class DelayedTaskQueue {
|
||
|
public:
|
||
|
~DelayedTaskQueue() {
|
||
|
while (!queue_.empty()) {
|
||
|
delete queue_.top();
|
||
|
queue_.pop();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
void process() {
|
||
|
while (!queue_.empty()) {
|
||
|
Task* d = queue_.top();
|
||
|
if (now() < d->startTime) {
|
||
|
break;
|
||
|
}
|
||
|
auto owned = std::unique_ptr<Task>(queue_.top());
|
||
|
queue_.pop();
|
||
|
owned->func();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
void push(Task* t) {
|
||
|
queue_.push(t);
|
||
|
}
|
||
|
|
||
|
bool empty() {
|
||
|
return queue_.empty();
|
||
|
}
|
||
|
|
||
|
time_point nextTime() {
|
||
|
return queue_.top()->startTime;
|
||
|
}
|
||
|
private:
|
||
|
std::priority_queue<Task*, std::vector<Task*>, Task::Compare> queue_;
|
||
|
};
|
||
|
|
||
|
}
|
||
|
|
||
|
// Implementation object behind CxxMessageQueue. It owns the atomic intrusive
// list of posted tasks, the heap of delayed tasks, and the run loop that
// drains both on the thread the queue is bound to.
class CxxMessageQueue::QueueRunner {
 public:
  // Frees any tasks still sitting in the posted list without running them.
  ~QueueRunner() {
    queue_.sweep([](Task* leftover) { delete leftover; });
  }

  // Posts `func` to run asynchronously.
  void enqueue(std::function<void()>&& func) {
    enqueueTask(Task::create(std::move(func)));
  }

  // Posts `func` to run no earlier than `delayMs` milliseconds from now. A
  // zero delay is treated as a plain enqueue().
  void enqueueDelayed(std::function<void()>&& func, uint64_t delayMs) {
    if (!delayMs) {
      enqueue(std::move(func));
      return;
    }
    const time_point startTime = now() + std::chrono::milliseconds(delayMs);
    enqueueTask(Task::createDelayed(std::move(func), startTime));
  }

  // Posts `func` and waits on a local EventFlag that the task sets after
  // `func` has run. Throws std::runtime_error if the queue is observed to be
  // stopped right after posting.
  //
  // NOTE(review): the posted task refers to `func` and `done` by reference.
  // If this call throws while the queue thread is already executing the
  // task, those references would dangle -- confirm callers never race
  // stop() against enqueueSync().
  void enqueueSync(std::function<void()>&& func) {
    EventFlag done;
    std::function<void()> runAndSignal = [&func, &done] {
      func();
      done.set();
    };
    enqueueTask(Task::createSync(std::move(runAndSignal)));
    if (stopped_.load()) {
      // If this queue is stopped_, the sync task might never actually run.
      throw std::runtime_error("Stopped within enqueueSync.");
    }
    done.wait();
  }

  // Marks the queue as stopped and signals pending_ so that a run loop
  // waiting on it re-checks the flag.
  void stop() {
    stopped_.store(true);
    pending_.set();
  }

  // True once stop() has been called.
  bool isStopped() {
    return stopped_.load();
  }

  // Stops the queue and waits for run() to signal that it has finished.
  void quitSynchronous() {
    stop();
    finished_.wait();
  }

  // The run loop: sweeps tasks until the queue is stopped, then signals
  // finished_.
  void run() {
    // If another thread stops this one, then the acquire-release on pending_
    // ensures that we read stopped some time after it was set (and other
    // threads just have to deal with the fact that we might run a task "after"
    // they stop us).
    //
    // If we are stopped on this thread, then memory order doesn't really
    // matter reading stopped_.
    while (!stopped_.load(std::memory_order_relaxed)) {
      sweep();
      if (!delayed_.empty()) {
        // Wake up no later than the earliest delayed task is due.
        pending_.wait_until(delayed_.nextTime());
      } else {
        pending_.wait();
      }
    }
    // This sweep is just to catch erroneous enqueueSync. That is, there could
    // be a task marked sync that another thread is waiting for, but we'll
    // never actually run it.
    sweep();
    finished_.set();
  }

  // We are processing two queues, the posted tasks (queue_) and the delayed
  // tasks (delayed_). Delayed tasks first go into posted tasks, and then are
  // moved to the delayed queue if we pop them before the time they are
  // scheduled for.
  // As we pop things from queue_, before dealing with that thing, we run any
  // delayed tasks whose scheduled time has arrived.
  void sweep() {
    queue_.sweep([this](Task* raw) {
      // Owns the task for the duration of this callback unless it is handed
      // over to delayed_ below.
      std::unique_ptr<Task> task(raw);

      if (stopped_.load(std::memory_order_relaxed)) {
        // Once stopped, tasks are dropped without running; a sync task here
        // means some thread may be waiting on it forever, so report it.
        if (task->sync) {
          throw std::runtime_error("Sync task posted while stopped.");
        }
        return;
      }

      delayed_.process();

      const bool notYetDue =
          task->startTime != time_point() && now() <= task->startTime;
      if (notYetDue) {
        delayed_.push(task.release());
      } else {
        task->func();
      }
    });
    delayed_.process();
  }

  // Records the calling thread as the queue's thread. Throws
  // std::runtime_error if a thread has already been recorded.
  void bindToThisThread() {
    const std::thread::id unbound{};
    if (tid_ != unbound) {
      throw std::runtime_error("Message queue already bound to thread.");
    }
    tid_ = std::this_thread::get_id();
  }

  // True when called from the thread recorded by bindToThisThread().
  bool isOnQueue() {
    return tid_ == std::this_thread::get_id();
  }

 private:
  // Links `task` into the posted list and signals pending_ when insertHead()
  // returns true (per folly: the list was empty before this insert -- see the
  // folly AtomicIntrusiveLinkedList documentation).
  void enqueueTask(Task* task) {
    const bool shouldSignal = queue_.insertHead(task);
    if (shouldSignal) {
      pending_.set();
    }
  }

  // Thread recorded by bindToThisThread(); default-constructed until then.
  std::thread::id tid_;

  // Tasks posted from any thread, waiting to be swept.
  folly::AtomicIntrusiveLinkedList<Task, &Task::hook> queue_;

  std::atomic_bool stopped_{false};
  // Delayed tasks that were swept before their start time.
  DelayedTaskQueue delayed_;

  // Signalled when there may be new work or when the queue is stopped.
  BinarySemaphore pending_;
  // Set by run() once the loop has exited.
  EventFlag finished_;
};
|
||
|
|
||
|
|
||
|
// Creates the queue together with a fresh QueueRunner.
CxxMessageQueue::CxxMessageQueue()
    : qr_(new QueueRunner()) {}
|
||
|
|
||
|
// Destroying a queue that has not been stopped is treated as a fatal error.
CxxMessageQueue::~CxxMessageQueue() {
  // TODO(cjhopman): Add detach() so that the queue doesn't have to be
  // explicitly stopped.
  if (qr_->isStopped()) {
    return;
  }
  LOG(FATAL) << "Queue not stopped.";
}
|
||
|
|
||
|
// Posts `func` to the queue for asynchronous execution.
void CxxMessageQueue::runOnQueue(std::function<void()>&& func) {
  qr_->enqueue(std::move(func));
}
|
||
|
|
||
|
// Posts `func` to the queue to run no earlier than `delayMs` milliseconds
// from now.
void CxxMessageQueue::runOnQueueDelayed(
    std::function<void()>&& func,
    uint64_t delayMs) {
  qr_->enqueueDelayed(std::move(func), delayMs);
}
|
||
|
|
||
|
void CxxMessageQueue::runOnQueueSync(std::function<void()>&& func) {
|
||
|
if (isOnQueue()) {
|
||
|
func();
|
||
|
return;
|
||
|
}
|
||
|
qr_->enqueueSync(std::move(func));
|
||
|
}
|
||
|
|
||
|
void CxxMessageQueue::quitSynchronous() {
|
||
|
if (isOnQueue()) {
|
||
|
qr_->stop();
|
||
|
} else {
|
||
|
qr_->quitSynchronous();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
bool CxxMessageQueue::isOnQueue() {
|
||
|
return qr_->isOnQueue();
|
||
|
}
|
||
|
|
||
|
namespace {
|
||
|
struct MQRegistry {
|
||
|
std::weak_ptr<CxxMessageQueue> find(std::thread::id tid) {
|
||
|
std::lock_guard<std::mutex> g(lock_);
|
||
|
auto iter = registry_.find(tid);
|
||
|
if (iter == registry_.end()) return std::weak_ptr<CxxMessageQueue>();
|
||
|
return iter->second;
|
||
|
}
|
||
|
|
||
|
void registerQueue(std::thread::id tid, std::weak_ptr<CxxMessageQueue> mq) {
|
||
|
std::lock_guard<std::mutex> g(lock_);
|
||
|
registry_[tid] = mq;
|
||
|
}
|
||
|
|
||
|
void unregister(std::thread::id tid) {
|
||
|
std::lock_guard<std::mutex> g(lock_);
|
||
|
registry_.erase(tid);
|
||
|
}
|
||
|
private:
|
||
|
std::mutex lock_;
|
||
|
std::unordered_map<std::thread::id, std::weak_ptr<CxxMessageQueue>> registry_;
|
||
|
};
|
||
|
|
||
|
// Returns the process-wide registry. It is allocated on first use and never
// deleted here, so it is not torn down during static destruction.
MQRegistry& getMQRegistry() {
  static MQRegistry* const instance = new MQRegistry();
  return *instance;
}
|
||
|
}
|
||
|
|
||
|
std::weak_ptr<CxxMessageQueue> CxxMessageQueue::current() {
|
||
|
auto tid = std::this_thread::get_id();
|
||
|
return getMQRegistry().find(tid);
|
||
|
}
|
||
|
|
||
|
std::function<void()> CxxMessageQueue::getRunLoop(std::shared_ptr<CxxMessageQueue> mq) {
|
||
|
return [capture=mq->qr_, weakMq=std::weak_ptr<CxxMessageQueue>(mq)] {
|
||
|
capture->bindToThisThread();
|
||
|
auto tid = std::this_thread::get_id();
|
||
|
|
||
|
// TODO: handle nested runloops (either allow them or throw an exception).
|
||
|
getMQRegistry().registerQueue(tid, weakMq);
|
||
|
capture->run();
|
||
|
getMQRegistry().unregister(tid);
|
||
|
};
|
||
|
}
|
||
|
|
||
|
|
||
|
|
||
|
} // namespace react
|
||
|
} // namespace facebook
|