// Parallel.h revision 360660
1//===- llvm/Support/Parallel.h - Parallel algorithms ----------------------===// 2// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6// 7//===----------------------------------------------------------------------===// 8 9#ifndef LLVM_SUPPORT_PARALLEL_H 10#define LLVM_SUPPORT_PARALLEL_H 11 12#include "llvm/ADT/STLExtras.h" 13#include "llvm/Config/llvm-config.h" 14#include "llvm/Support/MathExtras.h" 15 16#include <algorithm> 17#include <condition_variable> 18#include <functional> 19#include <mutex> 20 21#if defined(_MSC_VER) && LLVM_ENABLE_THREADS 22#pragma warning(push) 23#pragma warning(disable : 4530) 24#include <concrt.h> 25#include <ppl.h> 26#pragma warning(pop) 27#endif 28 29namespace llvm { 30 31namespace parallel { 32struct sequential_execution_policy {}; 33struct parallel_execution_policy {}; 34 35template <typename T> 36struct is_execution_policy 37 : public std::integral_constant< 38 bool, llvm::is_one_of<T, sequential_execution_policy, 39 parallel_execution_policy>::value> {}; 40 41constexpr sequential_execution_policy seq{}; 42constexpr parallel_execution_policy par{}; 43 44namespace detail { 45 46#if LLVM_ENABLE_THREADS 47 48class Latch { 49 uint32_t Count; 50 mutable std::mutex Mutex; 51 mutable std::condition_variable Cond; 52 53public: 54 explicit Latch(uint32_t Count = 0) : Count(Count) {} 55 ~Latch() { sync(); } 56 57 void inc() { 58 std::lock_guard<std::mutex> lock(Mutex); 59 ++Count; 60 } 61 62 void dec() { 63 std::lock_guard<std::mutex> lock(Mutex); 64 if (--Count == 0) 65 Cond.notify_all(); 66 } 67 68 void sync() const { 69 std::unique_lock<std::mutex> lock(Mutex); 70 Cond.wait(lock, [&] { return Count == 0; }); 71 } 72}; 73 74class TaskGroup { 75 Latch L; 76 bool Parallel; 77 78public: 79 TaskGroup(); 80 ~TaskGroup(); 81 82 void spawn(std::function<void()> f); 83 84 void sync() const { 
L.sync(); } 85}; 86 87#if defined(_MSC_VER) 88template <class RandomAccessIterator, class Comparator> 89void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End, 90 const Comparator &Comp) { 91 concurrency::parallel_sort(Start, End, Comp); 92} 93template <class IterTy, class FuncTy> 94void parallel_for_each(IterTy Begin, IterTy End, FuncTy Fn) { 95 concurrency::parallel_for_each(Begin, End, Fn); 96} 97 98template <class IndexTy, class FuncTy> 99void parallel_for_each_n(IndexTy Begin, IndexTy End, FuncTy Fn) { 100 concurrency::parallel_for(Begin, End, Fn); 101} 102 103#else 104const ptrdiff_t MinParallelSize = 1024; 105 106/// Inclusive median. 107template <class RandomAccessIterator, class Comparator> 108RandomAccessIterator medianOf3(RandomAccessIterator Start, 109 RandomAccessIterator End, 110 const Comparator &Comp) { 111 RandomAccessIterator Mid = Start + (std::distance(Start, End) / 2); 112 return Comp(*Start, *(End - 1)) 113 ? (Comp(*Mid, *(End - 1)) ? (Comp(*Start, *Mid) ? Mid : Start) 114 : End - 1) 115 : (Comp(*Mid, *Start) ? (Comp(*(End - 1), *Mid) ? Mid : End - 1) 116 : Start); 117} 118 119template <class RandomAccessIterator, class Comparator> 120void parallel_quick_sort(RandomAccessIterator Start, RandomAccessIterator End, 121 const Comparator &Comp, TaskGroup &TG, size_t Depth) { 122 // Do a sequential sort for small inputs. 123 if (std::distance(Start, End) < detail::MinParallelSize || Depth == 0) { 124 llvm::sort(Start, End, Comp); 125 return; 126 } 127 128 // Partition. 129 auto Pivot = medianOf3(Start, End, Comp); 130 // Move Pivot to End. 131 std::swap(*(End - 1), *Pivot); 132 Pivot = std::partition(Start, End - 1, [&Comp, End](decltype(*Start) V) { 133 return Comp(V, *(End - 1)); 134 }); 135 // Move Pivot to middle of partition. 136 std::swap(*Pivot, *(End - 1)); 137 138 // Recurse. 
139 TG.spawn([=, &Comp, &TG] { 140 parallel_quick_sort(Start, Pivot, Comp, TG, Depth - 1); 141 }); 142 parallel_quick_sort(Pivot + 1, End, Comp, TG, Depth - 1); 143} 144 145template <class RandomAccessIterator, class Comparator> 146void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End, 147 const Comparator &Comp) { 148 TaskGroup TG; 149 parallel_quick_sort(Start, End, Comp, TG, 150 llvm::Log2_64(std::distance(Start, End)) + 1); 151} 152 153template <class IterTy, class FuncTy> 154void parallel_for_each(IterTy Begin, IterTy End, FuncTy Fn) { 155 // TaskGroup has a relatively high overhead, so we want to reduce 156 // the number of spawn() calls. We'll create up to 1024 tasks here. 157 // (Note that 1024 is an arbitrary number. This code probably needs 158 // improving to take the number of available cores into account.) 159 ptrdiff_t TaskSize = std::distance(Begin, End) / 1024; 160 if (TaskSize == 0) 161 TaskSize = 1; 162 163 TaskGroup TG; 164 while (TaskSize < std::distance(Begin, End)) { 165 TG.spawn([=, &Fn] { std::for_each(Begin, Begin + TaskSize, Fn); }); 166 Begin += TaskSize; 167 } 168 std::for_each(Begin, End, Fn); 169} 170 171template <class IndexTy, class FuncTy> 172void parallel_for_each_n(IndexTy Begin, IndexTy End, FuncTy Fn) { 173 ptrdiff_t TaskSize = (End - Begin) / 1024; 174 if (TaskSize == 0) 175 TaskSize = 1; 176 177 TaskGroup TG; 178 IndexTy I = Begin; 179 for (; I + TaskSize < End; I += TaskSize) { 180 TG.spawn([=, &Fn] { 181 for (IndexTy J = I, E = I + TaskSize; J != E; ++J) 182 Fn(J); 183 }); 184 } 185 for (IndexTy J = I; J < End; ++J) 186 Fn(J); 187} 188 189#endif 190 191#endif 192 193template <typename Iter> 194using DefComparator = 195 std::less<typename std::iterator_traits<Iter>::value_type>; 196 197} // namespace detail 198 199// sequential algorithm implementations. 
200template <class Policy, class RandomAccessIterator, 201 class Comparator = detail::DefComparator<RandomAccessIterator>> 202void sort(Policy policy, RandomAccessIterator Start, RandomAccessIterator End, 203 const Comparator &Comp = Comparator()) { 204 static_assert(is_execution_policy<Policy>::value, 205 "Invalid execution policy!"); 206 llvm::sort(Start, End, Comp); 207} 208 209template <class Policy, class IterTy, class FuncTy> 210void for_each(Policy policy, IterTy Begin, IterTy End, FuncTy Fn) { 211 static_assert(is_execution_policy<Policy>::value, 212 "Invalid execution policy!"); 213 std::for_each(Begin, End, Fn); 214} 215 216template <class Policy, class IndexTy, class FuncTy> 217void for_each_n(Policy policy, IndexTy Begin, IndexTy End, FuncTy Fn) { 218 static_assert(is_execution_policy<Policy>::value, 219 "Invalid execution policy!"); 220 for (IndexTy I = Begin; I != End; ++I) 221 Fn(I); 222} 223 224// Parallel algorithm implementations, only available when LLVM_ENABLE_THREADS 225// is true. 226#if LLVM_ENABLE_THREADS 227template <class RandomAccessIterator, 228 class Comparator = detail::DefComparator<RandomAccessIterator>> 229void sort(parallel_execution_policy policy, RandomAccessIterator Start, 230 RandomAccessIterator End, const Comparator &Comp = Comparator()) { 231 detail::parallel_sort(Start, End, Comp); 232} 233 234template <class IterTy, class FuncTy> 235void for_each(parallel_execution_policy policy, IterTy Begin, IterTy End, 236 FuncTy Fn) { 237 detail::parallel_for_each(Begin, End, Fn); 238} 239 240template <class IndexTy, class FuncTy> 241void for_each_n(parallel_execution_policy policy, IndexTy Begin, IndexTy End, 242 FuncTy Fn) { 243 detail::parallel_for_each_n(Begin, End, Fn); 244} 245#endif 246 247} // namespace parallel 248} // namespace llvm 249 250#endif // LLVM_SUPPORT_PARALLEL_H 251