1292915Sdim//===-- WebAssemblyRegColoring.cpp - Register coloring --------------------===//
2292915Sdim//
3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4353358Sdim// See https://llvm.org/LICENSE.txt for license information.
5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6292915Sdim//
7292915Sdim//===----------------------------------------------------------------------===//
8292915Sdim///
9292915Sdim/// \file
10341825Sdim/// This file implements a virtual register coloring pass.
11292915Sdim///
12292915Sdim/// WebAssembly doesn't have a fixed number of registers, but it is still
13292915Sdim/// desirable to minimize the total number of registers used in each function.
14292915Sdim///
15292915Sdim/// This code is modeled after lib/CodeGen/StackSlotColoring.cpp.
16292915Sdim///
17292915Sdim//===----------------------------------------------------------------------===//
18292915Sdim
19292915Sdim#include "WebAssembly.h"
20292915Sdim#include "WebAssemblyMachineFunctionInfo.h"
21327952Sdim#include "llvm/CodeGen/LiveIntervals.h"
22292915Sdim#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
23292915Sdim#include "llvm/CodeGen/MachineRegisterInfo.h"
24292915Sdim#include "llvm/CodeGen/Passes.h"
25292915Sdim#include "llvm/Support/Debug.h"
26292915Sdim#include "llvm/Support/raw_ostream.h"
27292915Sdimusing namespace llvm;
28292915Sdim
29292915Sdim#define DEBUG_TYPE "wasm-reg-coloring"
30292915Sdim
31292915Sdimnamespace {
32292915Sdimclass WebAssemblyRegColoring final : public MachineFunctionPass {
33292915Sdimpublic:
34292915Sdim  static char ID; // Pass identification, replacement for typeid
35292915Sdim  WebAssemblyRegColoring() : MachineFunctionPass(ID) {}
36292915Sdim
37314564Sdim  StringRef getPassName() const override {
38292915Sdim    return "WebAssembly Register Coloring";
39292915Sdim  }
40292915Sdim
41292915Sdim  void getAnalysisUsage(AnalysisUsage &AU) const override {
42292915Sdim    AU.setPreservesCFG();
43292915Sdim    AU.addRequired<LiveIntervals>();
44292915Sdim    AU.addRequired<MachineBlockFrequencyInfo>();
45292915Sdim    AU.addPreserved<MachineBlockFrequencyInfo>();
46292915Sdim    AU.addPreservedID(MachineDominatorsID);
47292915Sdim    MachineFunctionPass::getAnalysisUsage(AU);
48292915Sdim  }
49292915Sdim
50292915Sdim  bool runOnMachineFunction(MachineFunction &MF) override;
51292915Sdim
52292915Sdimprivate:
53292915Sdim};
54292915Sdim} // end anonymous namespace
55292915Sdim
56292915Sdimchar WebAssemblyRegColoring::ID = 0;
57341825SdimINITIALIZE_PASS(WebAssemblyRegColoring, DEBUG_TYPE,
58341825Sdim                "Minimize number of registers used", false, false)
59341825Sdim
60292915SdimFunctionPass *llvm::createWebAssemblyRegColoring() {
61292915Sdim  return new WebAssemblyRegColoring();
62292915Sdim}
63292915Sdim
64292915Sdim// Compute the total spill weight for VReg.
65292915Sdimstatic float computeWeight(const MachineRegisterInfo *MRI,
66292915Sdim                           const MachineBlockFrequencyInfo *MBFI,
67292915Sdim                           unsigned VReg) {
68353358Sdim  float Weight = 0.0f;
69292915Sdim  for (MachineOperand &MO : MRI->reg_nodbg_operands(VReg))
70353358Sdim    Weight += LiveIntervals::getSpillWeight(MO.isDef(), MO.isUse(), MBFI,
71309124Sdim                                            *MO.getParent());
72353358Sdim  return Weight;
73292915Sdim}
74292915Sdim
75292915Sdimbool WebAssemblyRegColoring::runOnMachineFunction(MachineFunction &MF) {
76341825Sdim  LLVM_DEBUG({
77292915Sdim    dbgs() << "********** Register Coloring **********\n"
78292915Sdim           << "********** Function: " << MF.getName() << '\n';
79292915Sdim  });
80292915Sdim
81292915Sdim  // If there are calls to setjmp or sigsetjmp, don't perform coloring. Virtual
82292915Sdim  // registers could be modified before the longjmp is executed, resulting in
83292915Sdim  // the wrong value being used afterwards. (See <rdar://problem/8007500>.)
84292915Sdim  // TODO: Does WebAssembly need to care about setjmp for register coloring?
85292915Sdim  if (MF.exposesReturnsTwice())
86292915Sdim    return false;
87292915Sdim
88292915Sdim  MachineRegisterInfo *MRI = &MF.getRegInfo();
89292915Sdim  LiveIntervals *Liveness = &getAnalysis<LiveIntervals>();
90292915Sdim  const MachineBlockFrequencyInfo *MBFI =
91292915Sdim      &getAnalysis<MachineBlockFrequencyInfo>();
92292915Sdim  WebAssemblyFunctionInfo &MFI = *MF.getInfo<WebAssemblyFunctionInfo>();
93292915Sdim
94292915Sdim  // Gather all register intervals into a list and sort them.
95292915Sdim  unsigned NumVRegs = MRI->getNumVirtRegs();
96292915Sdim  SmallVector<LiveInterval *, 0> SortedIntervals;
97292915Sdim  SortedIntervals.reserve(NumVRegs);
98292915Sdim
99341825Sdim  LLVM_DEBUG(dbgs() << "Interesting register intervals:\n");
100353358Sdim  for (unsigned I = 0; I < NumVRegs; ++I) {
101360784Sdim    unsigned VReg = Register::index2VirtReg(I);
102292915Sdim    if (MFI.isVRegStackified(VReg))
103292915Sdim      continue;
104309124Sdim    // Skip unused registers, which can use $drop.
105292915Sdim    if (MRI->use_empty(VReg))
106292915Sdim      continue;
107292915Sdim
108292915Sdim    LiveInterval *LI = &Liveness->getInterval(VReg);
109292915Sdim    assert(LI->weight == 0.0f);
110292915Sdim    LI->weight = computeWeight(MRI, MBFI, VReg);
111341825Sdim    LLVM_DEBUG(LI->dump());
112292915Sdim    SortedIntervals.push_back(LI);
113292915Sdim  }
114341825Sdim  LLVM_DEBUG(dbgs() << '\n');
115292915Sdim
116292915Sdim  // Sort them to put arguments first (since we don't want to rename live-in
117292915Sdim  // registers), by weight next, and then by position.
118292915Sdim  // TODO: Investigate more intelligent sorting heuristics. For starters, we
119292915Sdim  // should try to coalesce adjacent live intervals before non-adjacent ones.
120344779Sdim  llvm::sort(SortedIntervals, [MRI](LiveInterval *LHS, LiveInterval *RHS) {
121344779Sdim    if (MRI->isLiveIn(LHS->reg) != MRI->isLiveIn(RHS->reg))
122344779Sdim      return MRI->isLiveIn(LHS->reg);
123344779Sdim    if (LHS->weight != RHS->weight)
124344779Sdim      return LHS->weight > RHS->weight;
125344779Sdim    if (LHS->empty() || RHS->empty())
126344779Sdim      return !LHS->empty() && RHS->empty();
127344779Sdim    return *LHS < *RHS;
128344779Sdim  });
129292915Sdim
130341825Sdim  LLVM_DEBUG(dbgs() << "Coloring register intervals:\n");
131292915Sdim  SmallVector<unsigned, 16> SlotMapping(SortedIntervals.size(), -1u);
132292915Sdim  SmallVector<SmallVector<LiveInterval *, 4>, 16> Assignments(
133292915Sdim      SortedIntervals.size());
134292915Sdim  BitVector UsedColors(SortedIntervals.size());
135292915Sdim  bool Changed = false;
136353358Sdim  for (size_t I = 0, E = SortedIntervals.size(); I < E; ++I) {
137353358Sdim    LiveInterval *LI = SortedIntervals[I];
138292915Sdim    unsigned Old = LI->reg;
139353358Sdim    size_t Color = I;
140292915Sdim    const TargetRegisterClass *RC = MRI->getRegClass(Old);
141292915Sdim
142292915Sdim    // Check if it's possible to reuse any of the used colors.
143292915Sdim    if (!MRI->isLiveIn(Old))
144321369Sdim      for (unsigned C : UsedColors.set_bits()) {
145292915Sdim        if (MRI->getRegClass(SortedIntervals[C]->reg) != RC)
146292915Sdim          continue;
147292915Sdim        for (LiveInterval *OtherLI : Assignments[C])
148292915Sdim          if (!OtherLI->empty() && OtherLI->overlaps(*LI))
149292915Sdim            goto continue_outer;
150292915Sdim        Color = C;
151292915Sdim        break;
152292915Sdim      continue_outer:;
153292915Sdim      }
154292915Sdim
155292915Sdim    unsigned New = SortedIntervals[Color]->reg;
156353358Sdim    SlotMapping[I] = New;
157292915Sdim    Changed |= Old != New;
158292915Sdim    UsedColors.set(Color);
159292915Sdim    Assignments[Color].push_back(LI);
160360784Sdim    LLVM_DEBUG(dbgs() << "Assigning vreg" << Register::virtReg2Index(LI->reg)
161360784Sdim                      << " to vreg" << Register::virtReg2Index(New) << "\n");
162292915Sdim  }
163292915Sdim  if (!Changed)
164292915Sdim    return false;
165292915Sdim
166292915Sdim  // Rewrite register operands.
167353358Sdim  for (size_t I = 0, E = SortedIntervals.size(); I < E; ++I) {
168353358Sdim    unsigned Old = SortedIntervals[I]->reg;
169353358Sdim    unsigned New = SlotMapping[I];
170292915Sdim    if (Old != New)
171292915Sdim      MRI->replaceRegWith(Old, New);
172292915Sdim  }
173292915Sdim  return true;
174292915Sdim}
175