1//===- BalancedPartitioning.cpp -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements BalancedPartitioning, a recursive balanced graph
10// partitioning algorithm.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Support/BalancedPartitioning.h"
15#include "llvm/Support/Debug.h"
16#include "llvm/Support/Format.h"
17#include "llvm/Support/FormatVariadic.h"
18#include "llvm/Support/ThreadPool.h"
19
20using namespace llvm;
21#define DEBUG_TYPE "balanced-partitioning"
22
23void BPFunctionNode::dump(raw_ostream &OS) const {
24  OS << formatv("{{ID={0} Utilities={{{1:$[,]}} Bucket={2}}", Id,
25                make_range(UtilityNodes.begin(), UtilityNodes.end()), Bucket);
26}
27
28template <typename Func>
29void BalancedPartitioning::BPThreadPool::async(Func &&F) {
30#if LLVM_ENABLE_THREADS
31  // This new thread could spawn more threads, so mark it as active
32  ++NumActiveThreads;
33  TheThreadPool.async([=]() {
34    // Run the task
35    F();
36
37    // This thread will no longer spawn new threads, so mark it as inactive
38    if (--NumActiveThreads == 0) {
39      // There are no more active threads, so mark as finished and notify
40      {
41        std::unique_lock<std::mutex> lock(mtx);
42        assert(!IsFinishedSpawning);
43        IsFinishedSpawning = true;
44      }
45      cv.notify_one();
46    }
47  });
48#else
49  llvm_unreachable("threads are disabled");
50#endif
51}
52
53void BalancedPartitioning::BPThreadPool::wait() {
54#if LLVM_ENABLE_THREADS
55  // TODO: We could remove the mutex and condition variable and use
56  // std::atomic::wait() instead, but that isn't available until C++20
57  {
58    std::unique_lock<std::mutex> lock(mtx);
59    cv.wait(lock, [&]() { return IsFinishedSpawning; });
60    assert(IsFinishedSpawning && NumActiveThreads == 0);
61  }
62  // Now we can call ThreadPool::wait() since all tasks have been submitted
63  TheThreadPool.wait();
64#else
65  llvm_unreachable("threads are disabled");
66#endif
67}
68
69BalancedPartitioning::BalancedPartitioning(
70    const BalancedPartitioningConfig &Config)
71    : Config(Config) {
72  // Pre-computing log2 values
73  Log2Cache[0] = 0.0;
74  for (unsigned I = 1; I < LOG_CACHE_SIZE; I++)
75    Log2Cache[I] = std::log2(I);
76}
77
78void BalancedPartitioning::run(std::vector<BPFunctionNode> &Nodes) const {
79  LLVM_DEBUG(
80      dbgs() << format(
81          "Partitioning %d nodes using depth %d and %d iterations per split\n",
82          Nodes.size(), Config.SplitDepth, Config.IterationsPerSplit));
83  std::optional<BPThreadPool> TP;
84#if LLVM_ENABLE_THREADS
85  ThreadPool TheThreadPool;
86  if (Config.TaskSplitDepth > 1)
87    TP.emplace(TheThreadPool);
88#endif
89
90  // Record the input order
91  for (unsigned I = 0; I < Nodes.size(); I++)
92    Nodes[I].InputOrderIndex = I;
93
94  auto NodesRange = llvm::make_range(Nodes.begin(), Nodes.end());
95  auto BisectTask = [=, &TP]() {
96    bisect(NodesRange, /*RecDepth=*/0, /*RootBucket=*/1, /*Offset=*/0, TP);
97  };
98  if (TP) {
99    TP->async(std::move(BisectTask));
100    TP->wait();
101  } else {
102    BisectTask();
103  }
104
105  llvm::stable_sort(NodesRange, [](const auto &L, const auto &R) {
106    return L.Bucket < R.Bucket;
107  });
108
109  LLVM_DEBUG(dbgs() << "Balanced partitioning completed\n");
110}
111
112void BalancedPartitioning::bisect(const FunctionNodeRange Nodes,
113                                  unsigned RecDepth, unsigned RootBucket,
114                                  unsigned Offset,
115                                  std::optional<BPThreadPool> &TP) const {
116  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
117  if (NumNodes <= 1 || RecDepth >= Config.SplitDepth) {
118    // We've reach the lowest level of the recursion tree. Fall back to the
119    // original order and assign to buckets.
120    llvm::sort(Nodes, [](const auto &L, const auto &R) {
121      return L.InputOrderIndex < R.InputOrderIndex;
122    });
123    for (auto &N : Nodes)
124      N.Bucket = Offset++;
125    return;
126  }
127
128  LLVM_DEBUG(dbgs() << format("Bisect with %d nodes and root bucket %d\n",
129                              NumNodes, RootBucket));
130
131  std::mt19937 RNG(RootBucket);
132
133  unsigned LeftBucket = 2 * RootBucket;
134  unsigned RightBucket = 2 * RootBucket + 1;
135
136  // Split into two and assign to the left and right buckets
137  split(Nodes, LeftBucket);
138
139  runIterations(Nodes, RecDepth, LeftBucket, RightBucket, RNG);
140
141  // Split nodes wrt the resulting buckets
142  auto NodesMid =
143      llvm::partition(Nodes, [&](auto &N) { return N.Bucket == LeftBucket; });
144  unsigned MidOffset = Offset + std::distance(Nodes.begin(), NodesMid);
145
146  auto LeftNodes = llvm::make_range(Nodes.begin(), NodesMid);
147  auto RightNodes = llvm::make_range(NodesMid, Nodes.end());
148
149  auto LeftRecTask = [=, &TP]() {
150    bisect(LeftNodes, RecDepth + 1, LeftBucket, Offset, TP);
151  };
152  auto RightRecTask = [=, &TP]() {
153    bisect(RightNodes, RecDepth + 1, RightBucket, MidOffset, TP);
154  };
155
156  if (TP && RecDepth < Config.TaskSplitDepth && NumNodes >= 4) {
157    TP->async(std::move(LeftRecTask));
158    TP->async(std::move(RightRecTask));
159  } else {
160    LeftRecTask();
161    RightRecTask();
162  }
163}
164
165void BalancedPartitioning::runIterations(const FunctionNodeRange Nodes,
166                                         unsigned RecDepth, unsigned LeftBucket,
167                                         unsigned RightBucket,
168                                         std::mt19937 &RNG) const {
169  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
170  DenseMap<BPFunctionNode::UtilityNodeT, unsigned> UtilityNodeIndex;
171  for (auto &N : Nodes)
172    for (auto &UN : N.UtilityNodes)
173      ++UtilityNodeIndex[UN];
174  // Remove utility nodes if they have just one edge or are connected to all
175  // functions
176  for (auto &N : Nodes)
177    llvm::erase_if(N.UtilityNodes, [&](auto &UN) {
178      return UtilityNodeIndex[UN] == 1 || UtilityNodeIndex[UN] == NumNodes;
179    });
180
181  // Renumber utility nodes so they can be used to index into Signatures
182  UtilityNodeIndex.clear();
183  for (auto &N : Nodes)
184    for (auto &UN : N.UtilityNodes)
185      UN = UtilityNodeIndex.insert({UN, UtilityNodeIndex.size()}).first->second;
186
187  // Initialize signatures
188  SignaturesT Signatures(/*Size=*/UtilityNodeIndex.size());
189  for (auto &N : Nodes) {
190    for (auto &UN : N.UtilityNodes) {
191      assert(UN < Signatures.size());
192      if (N.Bucket == LeftBucket) {
193        Signatures[UN].LeftCount++;
194      } else {
195        Signatures[UN].RightCount++;
196      }
197    }
198  }
199
200  for (unsigned I = 0; I < Config.IterationsPerSplit; I++) {
201    unsigned NumMovedNodes =
202        runIteration(Nodes, LeftBucket, RightBucket, Signatures, RNG);
203    if (NumMovedNodes == 0)
204      break;
205  }
206}
207
208unsigned BalancedPartitioning::runIteration(const FunctionNodeRange Nodes,
209                                            unsigned LeftBucket,
210                                            unsigned RightBucket,
211                                            SignaturesT &Signatures,
212                                            std::mt19937 &RNG) const {
213  // Init signature cost caches
214  for (auto &Signature : Signatures) {
215    if (Signature.CachedGainIsValid)
216      continue;
217    unsigned L = Signature.LeftCount;
218    unsigned R = Signature.RightCount;
219    assert((L > 0 || R > 0) && "incorrect signature");
220    float Cost = logCost(L, R);
221    Signature.CachedGainLR = 0.f;
222    Signature.CachedGainRL = 0.f;
223    if (L > 0)
224      Signature.CachedGainLR = Cost - logCost(L - 1, R + 1);
225    if (R > 0)
226      Signature.CachedGainRL = Cost - logCost(L + 1, R - 1);
227    Signature.CachedGainIsValid = true;
228  }
229
230  // Compute move gains
231  typedef std::pair<float, BPFunctionNode *> GainPair;
232  std::vector<GainPair> Gains;
233  for (auto &N : Nodes) {
234    bool FromLeftToRight = (N.Bucket == LeftBucket);
235    float Gain = moveGain(N, FromLeftToRight, Signatures);
236    Gains.push_back(std::make_pair(Gain, &N));
237  }
238
239  // Collect left and right gains
240  auto LeftEnd = llvm::partition(
241      Gains, [&](const auto &GP) { return GP.second->Bucket == LeftBucket; });
242  auto LeftRange = llvm::make_range(Gains.begin(), LeftEnd);
243  auto RightRange = llvm::make_range(LeftEnd, Gains.end());
244
245  // Sort gains in descending order
246  auto LargerGain = [](const auto &L, const auto &R) {
247    return L.first > R.first;
248  };
249  llvm::stable_sort(LeftRange, LargerGain);
250  llvm::stable_sort(RightRange, LargerGain);
251
252  unsigned NumMovedDataVertices = 0;
253  for (auto [LeftPair, RightPair] : llvm::zip(LeftRange, RightRange)) {
254    auto &[LeftGain, LeftNode] = LeftPair;
255    auto &[RightGain, RightNode] = RightPair;
256    // Stop when the gain is no longer beneficial
257    if (LeftGain + RightGain <= 0.f)
258      break;
259    // Try to exchange the nodes between buckets
260    if (moveFunctionNode(*LeftNode, LeftBucket, RightBucket, Signatures, RNG))
261      ++NumMovedDataVertices;
262    if (moveFunctionNode(*RightNode, LeftBucket, RightBucket, Signatures, RNG))
263      ++NumMovedDataVertices;
264  }
265  return NumMovedDataVertices;
266}
267
268bool BalancedPartitioning::moveFunctionNode(BPFunctionNode &N,
269                                            unsigned LeftBucket,
270                                            unsigned RightBucket,
271                                            SignaturesT &Signatures,
272                                            std::mt19937 &RNG) const {
273  // Sometimes we skip the move. This helps to escape local optima
274  if (std::uniform_real_distribution<float>(0.f, 1.f)(RNG) <=
275      Config.SkipProbability)
276    return false;
277
278  bool FromLeftToRight = (N.Bucket == LeftBucket);
279  // Update the current bucket
280  N.Bucket = (FromLeftToRight ? RightBucket : LeftBucket);
281
282  // Update signatures and invalidate gain cache
283  if (FromLeftToRight) {
284    for (auto &UN : N.UtilityNodes) {
285      auto &Signature = Signatures[UN];
286      Signature.LeftCount--;
287      Signature.RightCount++;
288      Signature.CachedGainIsValid = false;
289    }
290  } else {
291    for (auto &UN : N.UtilityNodes) {
292      auto &Signature = Signatures[UN];
293      Signature.LeftCount++;
294      Signature.RightCount--;
295      Signature.CachedGainIsValid = false;
296    }
297  }
298  return true;
299}
300
301void BalancedPartitioning::split(const FunctionNodeRange Nodes,
302                                 unsigned StartBucket) const {
303  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
304  auto NodesMid = Nodes.begin() + (NumNodes + 1) / 2;
305
306  std::nth_element(Nodes.begin(), NodesMid, Nodes.end(), [](auto &L, auto &R) {
307    return L.InputOrderIndex < R.InputOrderIndex;
308  });
309
310  for (auto &N : llvm::make_range(Nodes.begin(), NodesMid))
311    N.Bucket = StartBucket;
312  for (auto &N : llvm::make_range(NodesMid, Nodes.end()))
313    N.Bucket = StartBucket + 1;
314}
315
316float BalancedPartitioning::moveGain(const BPFunctionNode &N,
317                                     bool FromLeftToRight,
318                                     const SignaturesT &Signatures) {
319  float Gain = 0.f;
320  for (auto &UN : N.UtilityNodes)
321    Gain += (FromLeftToRight ? Signatures[UN].CachedGainLR
322                             : Signatures[UN].CachedGainRL);
323  return Gain;
324}
325
/// Cost of a signature with \p X nodes in the left bucket and \p Y nodes in
/// the right bucket: -(X * log2(X + 1) + Y * log2(Y + 1)).
// NOTE: kept as a single expression on purpose. Splitting it across
// statements may change floating-point contraction (FMA) and hence the exact
// gains that drive the partitioning.
float BalancedPartitioning::logCost(unsigned X, unsigned Y) const {
  return -(X * log2Cached(X + 1) + Y * log2Cached(Y + 1));
}
329
330float BalancedPartitioning::log2Cached(unsigned i) const {
331  return (i < LOG_CACHE_SIZE) ? Log2Cache[i] : std::log2(i);
332}
333