quickpool  0.2.0
An easy-to-use, header-only work stealing thread pool in C++11
quickpool.hpp
1 // Copyright 2021 Thomas Nagler (MIT License)
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 // SOFTWARE.
20 
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <exception>
#include <functional>
#include <future>
#include <limits>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>
29 
31 namespace quickpool {
32 
//! Todo list - a synchronization primitive.
//!
//! Tracks a signed count of outstanding tasks. Threads calling `wait()` block
//! until that count is no longer positive or until `stop()` has stored an
//! exception.
class TodoList
{
  public:
    //! constructs the list.
    //! @param num_tasks initial number of outstanding tasks.
    TodoList(size_t num_tasks = 0) noexcept
      : num_tasks_(num_tasks)
    {}

    //! adds tasks to the list.
    //! @param num_tasks number of tasks to add.
    void add(size_t num_tasks = 1) noexcept
    {
        num_tasks_.fetch_add(num_tasks);
    }

    //! crosses tasks off the list and wakes all waiters once nothing is left.
    //! @param num_tasks number of tasks to cross off.
    void cross(size_t num_tasks = 1)
    {
        num_tasks_.fetch_sub(num_tasks);
        if (num_tasks_.load() > 0)
            return;
        {
            // Waiters evaluate their predicate while holding mtx_. Passing
            // through the mutex here means a waiter is either still ahead of
            // its check or already blocked when the notification is sent.
            std::lock_guard<std::mutex> guard(mtx_);
        }
        cv_.notify_all();
    }

    //! checks whether list is empty.
    bool empty() const noexcept
    {
        return num_tasks_.load() <= 0;
    }

    //! waits until the list is empty or an exception was stored by `stop()`.
    //! @param millis if non-zero, give up waiting after this many
    //!     milliseconds; zero waits without a time limit.
    //! Rethrows the stored exception, if there is one.
    void wait(size_t millis = 0)
    {
        std::this_thread::yield();
        auto finished = [this] {
            return (num_tasks_.load() <= 0) || exception_ptr_;
        };
        std::unique_lock<std::mutex> guard(mtx_);
        if (millis != 0) {
            cv_.wait_for(guard, std::chrono::milliseconds(millis), finished);
        } else {
            cv_.wait(guard, finished);
        }
        if (exception_ptr_)
            std::rethrow_exception(exception_ptr_);
    }

    //! marks the list as done for good and wakes all waiters.
    //! @param eptr exception to be rethrown by `wait()`; may be null.
    void stop(std::exception_ptr eptr = nullptr) noexcept
    {
        {
            std::lock_guard<std::mutex> guard(mtx_);
            // Other threads may still add() or cross() after the stop. A
            // large negative count keeps the total from turning positive
            // again.
            num_tasks_.store(std::numeric_limits<int>::min() / 2);
            exception_ptr_ = eptr;
        }
        cv_.notify_all();
    }

  private:
    alignas(64) std::atomic_int num_tasks_{ 0 };
    std::mutex mtx_;
    std::condition_variable cv_;
    std::exception_ptr exception_ptr_{ nullptr };
};
103 
105 namespace detail {
106 
//! A ring buffer whose slots are addressed by unbounded indices.
//!
//! An index is mapped to a slot by masking it with `capacity() - 1`, which is
//! only a valid wrap-around when the capacity is a power of two. The
//! constructor therefore rounds the requested capacity up accordingly.
template<typename T>
class RingBuffer
{
  public:
    //! constructs the buffer.
    //! @param capacity requested number of slots; rounded up to the next
    //!     power of two (at least 1). All slots are value-initialized.
    explicit RingBuffer(size_t capacity)
      : capacity_{ round_up_to_power_of_two(capacity) }
      , mask_{ capacity_ - 1 }
      , buffer_{ new T[capacity_]() }
    {}

    //! returns the actual number of slots.
    size_t capacity() const { return capacity_; }

    //! stores `val` in the slot that index `i` maps to.
    void set_entry(size_t i, T val) { buffer_[i & mask_] = val; }

    //! returns a copy of the value in the slot that index `i` maps to.
    T get_entry(size_t i) const { return buffer_[i & mask_]; }

    //! creates a buffer with twice the capacity holding the entries with
    //! indices in `[top, bottom)`, each under its original index.
    //! @return an owning raw pointer; the caller must `delete` it.
    RingBuffer<T>* enlarged_copy(size_t bottom, size_t top) const
    {
        // Held by a unique_ptr until the copy is complete so the new buffer
        // is not leaked if copying an entry throws.
        std::unique_ptr<RingBuffer<T>> enlarged(
          new RingBuffer<T>{ 2 * capacity_ });
        for (size_t i = top; i != bottom; ++i)
            enlarged->set_entry(i, this->get_entry(i));
        return enlarged.release();
    }

  private:
    //! smallest power of two that is >= n (and >= 1).
    static size_t round_up_to_power_of_two(size_t n) noexcept
    {
        const size_t largest_power = ~(static_cast<size_t>(-1) >> 1);
        size_t result = 1;
        // Stop at the largest representable power instead of overflowing.
        while (result < n && result != largest_power)
            result <<= 1;
        return result;
    }

    // Declared in this order because mask_ and buffer_ are initialized from
    // capacity_.
    size_t capacity_;
    size_t mask_;
    std::unique_ptr<T[]> buffer_;
};
138 
// std::exchange is not available in C++11; this stand-in is modelled on the
// implementation shown at
// https://en.cppreference.com/w/cpp/utility/exchange
//
// Stores `new_value` in `obj` and returns the value `obj` held before. Both
// parameters share one template parameter, so `new_value` must be an rvalue
// of exactly the type of `obj`.
template<class T>
T
exchange(T& obj, T&& new_value) noexcept
{
    T previous = std::move(obj);
    obj = std::move(new_value);
    return previous;
}
149 
151 class TaskQueue
152 {
153  using Task = std::function<void()>;
154 
155  public:
158  TaskQueue(size_t capacity = 256)
159  : buffer_{ new RingBuffer<Task*>(capacity) }
160  {}
161 
162  ~TaskQueue() noexcept
163  {
164  // must free memory allocated by push(), but not deallocated by pop()
165  auto buf_ptr = buffer_.load();
166  for (int i = top_; i < bottom_.load(m_relaxed); ++i)
167  delete buf_ptr->get_entry(i);
168  delete buf_ptr;
169  }
170 
171  TaskQueue(TaskQueue const& other) = delete;
172  TaskQueue& operator=(TaskQueue const& other) = delete;
173 
175  bool empty() const
176  {
177  return (bottom_.load(m_relaxed) <= top_.load(m_relaxed));
178  }
179 
182  bool try_push(Task&& task)
183  {
184  {
185  // must hold lock in case of multiple producers, abort if already
186  // taken, so we can check out next queue
187  std::unique_lock<std::mutex> lk(mutex_, std::try_to_lock);
188  if (!lk)
189  return false;
190  this->push_unsafe(std::forward<Task>(task));
191  }
192  cv_.notify_one();
193  return true;
194  }
195 
197  void force_push(Task&& task)
198  {
199  {
200  // must hold lock in case of multiple producers
201  std::lock_guard<std::mutex> lk(mutex_);
202  this->push_unsafe(std::forward<Task>(task));
203  }
204  cv_.notify_one();
205  }
206 
208  void push_unsafe(Task&& task)
209  {
210  auto b = bottom_.load(m_relaxed);
211  auto t = top_.load(m_acquire);
212  RingBuffer<Task*>* buf_ptr = buffer_.load(m_relaxed);
213 
214  if (static_cast<int>(buf_ptr->capacity()) < (b - t) + 1) {
215  // buffer is full, create enlarged copy before continuing
216  old_buffers_.emplace_back(
217  exchange(buf_ptr, buf_ptr->enlarged_copy(b, t)));
218  buffer_.store(buf_ptr, m_relaxed);
219  }
220 
221  buf_ptr->set_entry(b, new Task{ std::forward<Task>(task) });
222  bottom_.store(b + 1, m_release);
223  }
224 
226  bool try_pop(Task& task)
227  {
228  auto t = top_.load(m_acquire);
229  std::atomic_thread_fence(m_seq_cst);
230  auto b = bottom_.load(m_acquire);
231 
232  if (t < b) {
233  // must load task pointer before acquiring the slot, because it
234  // could be overwritten immediately after
235  auto task_ptr = buffer_.load(m_acquire)->get_entry(t);
236 
237  if (top_.compare_exchange_strong(t, t + 1, m_seq_cst, m_relaxed)) {
238  task = std::move(*task_ptr); // won race, get task
239  delete task_ptr; // fre memory allocated in push_unsafe()
240  return true;
241  }
242  }
243  return false; // queue is empty or lost race
244  }
245 
246  void wait()
247  {
248  std::unique_lock<std::mutex> lk(mutex_);
249  cv_.wait(lk, [this] { return !this->empty() || stopped_; });
250  }
251 
252  void stop()
253  {
254  {
255  std::lock_guard<std::mutex> lk(mutex_);
256  stopped_ = true;
257  }
258  cv_.notify_all();
259  }
260 
261  private:
262  alignas(64) std::atomic_int top_{ 0 };
263  alignas(64) std::atomic_int bottom_{ 0 };
264  alignas(64) std::atomic<RingBuffer<Task*>*> buffer_{ nullptr };
265  std::vector<std::unique_ptr<RingBuffer<Task*>>> old_buffers_;
266  std::mutex mutex_;
267  std::condition_variable cv_;
268  std::atomic<bool> stopped_;
269 
270  // convenience aliases
271  static constexpr std::memory_order m_relaxed = std::memory_order_relaxed;
272  static constexpr std::memory_order m_acquire = std::memory_order_acquire;
273  static constexpr std::memory_order m_release = std::memory_order_release;
274  static constexpr std::memory_order m_seq_cst = std::memory_order_seq_cst;
275 };
276 
278 struct TaskManager
279 {
280  std::vector<TaskQueue> queues_;
281  size_t num_queues_;
282  alignas(64) std::atomic_size_t push_idx_{ 0 };
283  std::atomic_bool stopped_{ false };
284  std::atomic_size_t todo_list_{ 0 };
285 
286  explicit TaskManager(size_t num_queues)
287  : queues_{ std::vector<TaskQueue>(num_queues) }
288  , num_queues_{ num_queues }
289  {}
290 
291  template<typename Task>
292  void push(Task&& task)
293  {
294  if (stopped_)
295  return;
296  for (size_t count = 0; count < num_queues_ * 20; count++) {
297  if (queues_[push_idx_++ % num_queues_].try_push(task))
298  return;
299  }
300  queues_[push_idx_++ % num_queues_].force_push(task);
301  }
302 
303  template<typename Task>
304  bool try_pop(Task& task, size_t worker_id = 0)
305  {
306  if (stopped_)
307  return false;
308  for (size_t k = 0; k <= num_queues_; k++) {
309  if (queues_[(worker_id + k) % num_queues_].try_pop(task))
310  return true;
311  }
312  return false;
313  }
314 
315  void wait_for_jobs(size_t id) { queues_[id].wait(); }
316 
317  void stop()
318  {
319  for (auto& q : queues_)
320  q.stop();
321  stopped_ = true;
322  }
323 
324  bool stopped() { return stopped_; }
325 };
326 
327 } // end namespace detail
328 
331 {
332  public:
335  : ThreadPool(std::thread::hardware_concurrency())
336  {}
337 
341  explicit ThreadPool(size_t n_workers)
342  : task_manager_{ n_workers }
343  {
344  for (size_t id = 0; id < n_workers; ++id) {
345  workers_.emplace_back([this, id] {
346  std::function<void()> task;
347  while (!task_manager_.stopped()) {
348  task_manager_.wait_for_jobs(id);
349  do {
350  // inner while to save a few cash misses calling empty()
351  if (task_manager_.try_pop(task, id))
352  this->execute_safely(task);
353  } while (!todo_list_.empty());
354  }
355  });
356  }
357  }
358 
359  ~ThreadPool()
360  {
361  task_manager_.stop();
362  for (auto& worker : workers_) {
363  if (worker.joinable())
364  worker.join();
365  }
366  }
367 
368  ThreadPool(ThreadPool&&) = delete;
369  ThreadPool(const ThreadPool&) = delete;
370  ThreadPool& operator=(const ThreadPool&) = delete;
371  ThreadPool& operator=(ThreadPool&& other) = delete;
372 
375  {
376  static ThreadPool instance_;
377  return instance_;
378  }
379 
383  template<class Function, class... Args>
384  void push(Function&& f, Args&&... args)
385  {
386  if (workers_.size() == 0)
387  return f(args...);
388  todo_list_.add();
389  task_manager_.push(
390  std::bind(std::forward<Function>(f), std::forward<Args>(args)...));
391  }
392 
398  template<class Function, class... Args>
399  auto async(Function&& f, Args&&... args)
400  -> std::future<decltype(f(args...))>
401  {
402  auto pack =
403  std::bind(std::forward<Function>(f), std::forward<Args>(args)...);
404  using pack_t = std::packaged_task<decltype(f(args...))()>;
405  auto task_ptr = std::make_shared<pack_t>(std::move(pack));
406  this->push([task_ptr] { (*task_ptr)(); });
407  return task_ptr->get_future();
408  }
409 
411  void wait() { todo_list_.wait(); }
412 
413  private:
414  void execute_safely(std::function<void()>& task)
415  {
416  try {
417  task();
418  todo_list_.cross();
419  } catch (...) {
420  todo_list_.stop(std::current_exception());
421  task_manager_.stop();
422  }
423  }
424 
425  detail::TaskManager task_manager_;
426  TodoList todo_list_{ 0 };
427  std::vector<std::thread> workers_;
428 };
429 
431 
435 template<class Function, class... Args>
436 void
437 push(Function&& f, Args&&... args)
438 {
439  ThreadPool::global_instance().push(std::forward<Function>(f),
440  std::forward<Args>(args)...);
441 }
442 
448 template<class Function, class... Args>
449 auto
450 async(Function&& f, Args&&... args) -> std::future<decltype(f(args...))>
451 {
452  return ThreadPool::global_instance().async(std::forward<Function>(f),
453  std::forward<Args>(args)...);
454 }
455 
457 inline void
459 {
461 }
462 
463 } // end namespace quickpool
quickpool
quickpool namespace
Definition: quickpool.hpp:31
quickpool::ThreadPool::ThreadPool
ThreadPool()
constructs a thread pool with as many workers as there are cores.
Definition: quickpool.hpp:334
quickpool::TodoList::TodoList
TodoList(size_t num_tasks=0) noexcept
Definition: quickpool.hpp:41
quickpool::async
auto async(Function &&f, Args &&... args) -> std::future< decltype(f(args...))>
executes a job asynchronously on the global thread pool.
Definition: quickpool.hpp:450
quickpool::ThreadPool::wait
void wait()
waits for all jobs currently running on the thread pool.
Definition: quickpool.hpp:411
quickpool::TodoList::add
void add(size_t num_tasks=1) noexcept
Definition: quickpool.hpp:47
quickpool::ThreadPool::global_instance
static ThreadPool & global_instance()
returns a reference to the global thread pool instance.
Definition: quickpool.hpp:374
quickpool::wait
void wait()
waits for all jobs currently running on the global thread pool.
Definition: quickpool.hpp:458
quickpool::TodoList::stop
void stop(std::exception_ptr eptr=nullptr) noexcept
Definition: quickpool.hpp:85
quickpool::TodoList::cross
void cross(size_t num_tasks=1)
Definition: quickpool.hpp:51
quickpool::TodoList
Todo list - a synchronization primitive.
Definition: quickpool.hpp:36
quickpool::ThreadPool::ThreadPool
ThreadPool(size_t n_workers)
Definition: quickpool.hpp:341
quickpool::TodoList::empty
bool empty() const noexcept
checks whether list is empty.
Definition: quickpool.hpp:63
quickpool::ThreadPool::async
auto async(Function &&f, Args &&... args) -> std::future< decltype(f(args...))>
executes a job asynchronously on the thread pool.
Definition: quickpool.hpp:399
quickpool::ThreadPool
A work stealing thread pool.
Definition: quickpool.hpp:330
quickpool::ThreadPool::push
void push(Function &&f, Args &&... args)
pushes a job to the thread pool.
Definition: quickpool.hpp:384
quickpool::TodoList::wait
void wait(size_t millis=0)
Definition: quickpool.hpp:68
quickpool::push
void push(Function &&f, Args &&... args)
Direct access to the global thread pool: pushes a job to the global thread pool.
Definition: quickpool.hpp:437