diff --git a/src/threadpool.cpp b/src/threadpool.cpp index 42f3693..5ef3217 100644 --- a/src/threadpool.cpp +++ b/src/threadpool.cpp @@ -51,7 +51,7 @@ void ThreadPool::spawn(int count) { void ThreadPool::join() { { std::unique_lock lock(mMutex); - mWaitingCondition.wait(lock, [&]() { return mWaitingWorkers == int(mWorkers.size()); }); + mWaitingCondition.wait(lock, [&]() { return mBusyWorkers == 0; }); mJoining = true; mTasksCondition.notify_all(); } @@ -66,6 +66,8 @@ void ThreadPool::join() { } void ThreadPool::run() { + ++mBusyWorkers; + scope_guard([&]() { --mBusyWorkers; }); while (runOne()) { } } @@ -81,24 +83,23 @@ bool ThreadPool::runOne() { std::function ThreadPool::dequeue() { std::unique_lock lock(mMutex); while (!mJoining) { + std::optional time; if (!mTasks.empty()) { - if (mTasks.top().time <= clock::now()) { + time = mTasks.top().time; + if (*time <= clock::now()) { auto func = std::move(mTasks.top().func); mTasks.pop(); return func; } - - ++mWaitingWorkers; - mWaitingCondition.notify_all(); - mTasksCondition.wait_until(lock, mTasks.top().time); - - } else { - ++mWaitingWorkers; - mWaitingCondition.notify_all(); - mTasksCondition.wait(lock); } - --mWaitingWorkers; + --mBusyWorkers; + scope_guard([&]() { ++mBusyWorkers; }); + mWaitingCondition.notify_all(); + if(time) + mTasksCondition.wait_until(lock, *time); + else + mTasksCondition.wait(lock); } return nullptr; } diff --git a/src/threadpool.hpp b/src/threadpool.hpp index 70885b5..640b3df 100644 --- a/src/threadpool.hpp +++ b/src/threadpool.hpp @@ -72,7 +72,7 @@ protected: std::function dequeue(); // returns null function if joining std::vector mWorkers; - int mWaitingWorkers = 0; + int mBusyWorkers = 0; std::atomic mJoining = false; struct Task {