Parallel.h revision 321369
//===- llvm/Support/Parallel.h - Parallel algorithms ----------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_SUPPORT_PARALLEL_H
#define LLVM_SUPPORT_PARALLEL_H

#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/MathExtras.h"

#include <algorithm>
#include <condition_variable>
#include <functional>
#include <mutex>

#if defined(_MSC_VER) && LLVM_ENABLE_THREADS
#pragma warning(push)
#pragma warning(disable : 4530)
#include <concrt.h>
#include <ppl.h>
#pragma warning(pop)
#endif

namespace llvm {

namespace parallel {
struct sequential_execution_policy {};
struct parallel_execution_policy {};

template <typename T>
struct is_execution_policy
    : public std::integral_constant<
          bool, llvm::is_one_of<T, sequential_execution_policy,
                                parallel_execution_policy>::value> {};

constexpr sequential_execution_policy seq{};
constexpr parallel_execution_policy par{};

namespace detail {

#if LLVM_ENABLE_THREADS

class Latch {
  uint32_t Count;
  mutable std::mutex Mutex;
  mutable std::condition_variable Cond;

public:
  explicit Latch(uint32_t Count = 0) : Count(Count) {}
  ~Latch() { sync(); }

  void inc() {
    std::unique_lock<std::mutex> lock(Mutex);
    ++Count;
  }

  void dec() {
    std::unique_lock<std::mutex> lock(Mutex);
    if (--Count == 0)
      Cond.notify_all();
  }

  void sync() const {
    std::unique_lock<std::mutex> lock(Mutex);
    Cond.wait(lock, [&] { return Count == 0; });
  }
};

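// Usage sketch (illustrative only; EnqueueWork and doWork are hypothetical):
// a Latch counts outstanding work items. A producer calls inc() before handing
// off a unit of work, the worker calls dec() when it finishes, and sync()
// blocks until the count drops back to zero.
//
//   Latch L;
//   L.inc();
//   EnqueueWork([&] { doWork(); L.dec(); });
//   L.sync(); // returns once the enqueued work has called dec()
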
class TaskGroup {
  Latch L;

public:
  void spawn(std::function<void()> f);

  void sync() const { L.sync(); }
};

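// Usage sketch (illustrative only; doPartA and doPartB are hypothetical):
// spawn() schedules a callable to run asynchronously, and sync() blocks until
// every spawned callable has completed. The embedded Latch also performs a
// final sync() from its destructor.
//
//   TaskGroup TG;
//   TG.spawn([] { doPartA(); });
//   TG.spawn([] { doPartB(); });
//   TG.sync(); // both parts are guaranteed to be done here
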
#if defined(_MSC_VER)
template <class RandomAccessIterator, class Comparator>
void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End,
                   const Comparator &Comp) {
  concurrency::parallel_sort(Start, End, Comp);
}
template <class IterTy, class FuncTy>
void parallel_for_each(IterTy Begin, IterTy End, FuncTy Fn) {
  concurrency::parallel_for_each(Begin, End, Fn);
}

template <class IndexTy, class FuncTy>
void parallel_for_each_n(IndexTy Begin, IndexTy End, FuncTy Fn) {
  concurrency::parallel_for(Begin, End, Fn);
}

#else
const ptrdiff_t MinParallelSize = 1024;

/// \brief Inclusive median.
template <class RandomAccessIterator, class Comparator>
RandomAccessIterator medianOf3(RandomAccessIterator Start,
                               RandomAccessIterator End,
                               const Comparator &Comp) {
  RandomAccessIterator Mid = Start + (std::distance(Start, End) / 2);
  return Comp(*Start, *(End - 1))
             ? (Comp(*Mid, *(End - 1)) ? (Comp(*Start, *Mid) ? Mid : Start)
                                       : End - 1)
             : (Comp(*Mid, *Start) ? (Comp(*(End - 1), *Mid) ? Mid : End - 1)
                                   : Start);
}
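// For example, with the default std::less comparator and candidate values
// {*Start, *Mid, *(End - 1)} = {3, 9, 5}, the expression above evaluates
// 3 < 5 (true) and then 9 < 5 (false), returning End - 1: the iterator to 5,
// the median of the three candidates.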

template <class RandomAccessIterator, class Comparator>
void parallel_quick_sort(RandomAccessIterator Start, RandomAccessIterator End,
                         const Comparator &Comp, TaskGroup &TG, size_t Depth) {
  // Do a sequential sort for small inputs.
  if (std::distance(Start, End) < detail::MinParallelSize || Depth == 0) {
    std::sort(Start, End, Comp);
    return;
  }

  // Partition.
  auto Pivot = medianOf3(Start, End, Comp);
  // Move Pivot to End.
  std::swap(*(End - 1), *Pivot);
  Pivot = std::partition(Start, End - 1, [&Comp, End](decltype(*Start) V) {
    return Comp(V, *(End - 1));
  });
  // Move Pivot to middle of partition.
  std::swap(*Pivot, *(End - 1));

  // Recurse.
  TG.spawn([=, &Comp, &TG] {
    parallel_quick_sort(Start, Pivot, Comp, TG, Depth - 1);
  });
  parallel_quick_sort(Pivot + 1, End, Comp, TG, Depth - 1);
}

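// parallel_sort caps the recursion depth of parallel_quick_sort at
// log2(n) + 1. Once a subrange exhausts its depth budget, or shrinks below
// MinParallelSize, it is handled by std::sort, which bounds the number of
// spawned tasks and guards against quicksort's quadratic worst case.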
template <class RandomAccessIterator, class Comparator>
void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End,
                   const Comparator &Comp) {
  TaskGroup TG;
  parallel_quick_sort(Start, End, Comp, TG,
                      llvm::Log2_64(std::distance(Start, End)) + 1);
}

template <class IterTy, class FuncTy>
void parallel_for_each(IterTy Begin, IterTy End, FuncTy Fn) {
  // TaskGroup has a relatively high overhead, so we want to reduce
  // the number of spawn() calls. We'll create up to 1024 tasks here.
  // (Note that 1024 is an arbitrary number. This code probably needs
  // improving to take the number of available cores into account.)
  ptrdiff_t TaskSize = std::distance(Begin, End) / 1024;
  if (TaskSize == 0)
    TaskSize = 1;

  TaskGroup TG;
  while (TaskSize <= std::distance(Begin, End)) {
    TG.spawn([=, &Fn] { std::for_each(Begin, Begin + TaskSize, Fn); });
    Begin += TaskSize;
  }
  TG.spawn([=, &Fn] { std::for_each(Begin, End, Fn); });
}
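
// Worked example of the chunking above (illustrative numbers): for a range of
// 1,000,000 elements, TaskSize becomes 976, so the loop spawns 1024 tasks of
// 976 elements each plus one final task for the 576-element remainder. For
// ranges shorter than 1024 elements, TaskSize is clamped to 1 and every
// element gets its own task.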

template <class IndexTy, class FuncTy>
void parallel_for_each_n(IndexTy Begin, IndexTy End, FuncTy Fn) {
  ptrdiff_t TaskSize = (End - Begin) / 1024;
  if (TaskSize == 0)
    TaskSize = 1;

  TaskGroup TG;
  IndexTy I = Begin;
  for (; I + TaskSize < End; I += TaskSize) {
    TG.spawn([=, &Fn] {
      for (IndexTy J = I, E = I + TaskSize; J != E; ++J)
        Fn(J);
    });
  }
  TG.spawn([=, &Fn] {
    for (IndexTy J = I; J < End; ++J)
      Fn(J);
  });
}
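
// The index-based variant applies the same chunking to an index range: a call
// such as parallel_for_each_n(size_t(0), N, Fn) (N and Fn being hypothetical
// caller-side names) invokes Fn(0) through Fn(N - 1), with each spawned task
// covering one contiguous block of indices.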

#endif

#endif

template <typename Iter>
using DefComparator =
    std::less<typename std::iterator_traits<Iter>::value_type>;

} // namespace detail

// sequential algorithm implementations.
template <class Policy, class RandomAccessIterator,
          class Comparator = detail::DefComparator<RandomAccessIterator>>
void sort(Policy policy, RandomAccessIterator Start, RandomAccessIterator End,
          const Comparator &Comp = Comparator()) {
  static_assert(is_execution_policy<Policy>::value,
                "Invalid execution policy!");
  std::sort(Start, End, Comp);
}

template <class Policy, class IterTy, class FuncTy>
void for_each(Policy policy, IterTy Begin, IterTy End, FuncTy Fn) {
  static_assert(is_execution_policy<Policy>::value,
                "Invalid execution policy!");
  std::for_each(Begin, End, Fn);
}

template <class Policy, class IndexTy, class FuncTy>
void for_each_n(Policy policy, IndexTy Begin, IndexTy End, FuncTy Fn) {
  static_assert(is_execution_policy<Policy>::value,
                "Invalid execution policy!");
  for (IndexTy I = Begin; I != End; ++I)
    Fn(I);
}

// Parallel algorithm implementations, only available when LLVM_ENABLE_THREADS
// is true.
#if LLVM_ENABLE_THREADS
template <class RandomAccessIterator,
          class Comparator = detail::DefComparator<RandomAccessIterator>>
void sort(parallel_execution_policy policy, RandomAccessIterator Start,
          RandomAccessIterator End, const Comparator &Comp = Comparator()) {
  detail::parallel_sort(Start, End, Comp);
}

template <class IterTy, class FuncTy>
void for_each(parallel_execution_policy policy, IterTy Begin, IterTy End,
              FuncTy Fn) {
  detail::parallel_for_each(Begin, End, Fn);
}

template <class IndexTy, class FuncTy>
void for_each_n(parallel_execution_policy policy, IndexTy Begin, IndexTy End,
                FuncTy Fn) {
  detail::parallel_for_each_n(Begin, End, Fn);
}
#endif
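
// Caller-side sketch (Vec, getItems, and computeItem are hypothetical names):
//
//   std::vector<int> Vec = getItems();
//   // llvm::parallel::par selects the parallel overloads above when
//   // LLVM_ENABLE_THREADS is set; llvm::parallel::seq selects the sequential
//   // implementations.
//   llvm::parallel::sort(llvm::parallel::par, Vec.begin(), Vec.end());
//   llvm::parallel::for_each(llvm::parallel::par, Vec.begin(), Vec.end(),
//                            [](int &X) { X *= 2; });
//   llvm::parallel::for_each_n(llvm::parallel::par, size_t(0), Vec.size(),
//                              [&](size_t I) { computeItem(I); });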

} // namespace parallel
} // namespace llvm

#endif // LLVM_SUPPORT_PARALLEL_H