1//=== WebAssemblyLateEHPrepare.cpp - WebAssembly Exception Preparation -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// \brief Does various transformations for exception handling.
11///
12//===----------------------------------------------------------------------===//
13
14#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
15#include "Utils/WebAssemblyUtilities.h"
16#include "WebAssembly.h"
17#include "WebAssemblySubtarget.h"
18#include "llvm/ADT/SmallPtrSet.h"
19#include "llvm/CodeGen/MachineFunctionPass.h"
20#include "llvm/CodeGen/MachineInstrBuilder.h"
21#include "llvm/CodeGen/WasmEHFuncInfo.h"
22#include "llvm/MC/MCAsmInfo.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Target/TargetMachine.h"
25using namespace llvm;
26
27#define DEBUG_TYPE "wasm-late-eh-prepare"
28
29namespace {
30class WebAssemblyLateEHPrepare final : public MachineFunctionPass {
31  StringRef getPassName() const override {
32    return "WebAssembly Late Prepare Exception";
33  }
34
35  bool runOnMachineFunction(MachineFunction &MF) override;
36  bool removeUnreachableEHPads(MachineFunction &MF);
37  void recordCatchRetBBs(MachineFunction &MF);
38  bool hoistCatches(MachineFunction &MF);
39  bool addCatchAlls(MachineFunction &MF);
40  bool replaceFuncletReturns(MachineFunction &MF);
41  bool removeUnnecessaryUnreachables(MachineFunction &MF);
42  bool restoreStackPointer(MachineFunction &MF);
43
44  MachineBasicBlock *getMatchingEHPad(MachineInstr *MI);
45  SmallPtrSet<MachineBasicBlock *, 8> CatchRetBBs;
46
47public:
48  static char ID; // Pass identification, replacement for typeid
49  WebAssemblyLateEHPrepare() : MachineFunctionPass(ID) {}
50};
51} // end anonymous namespace
52
// Storage for the pass identifier; its address, not its value, identifies the
// pass.
char WebAssemblyLateEHPrepare::ID = 0;
// Register the pass under the command-line name DEBUG_TYPE
// ("wasm-late-eh-prepare") with a human-readable description. The two trailing
// `false` arguments are INITIALIZE_PASS's CFG-only and is-analysis flags.
INITIALIZE_PASS(WebAssemblyLateEHPrepare, DEBUG_TYPE,
                "WebAssembly Late Exception Preparation", false, false)

// Factory for this pass; returns a freshly heap-allocated instance (presumably
// owned by the pass manager that adds it -- confirm at the call site).
FunctionPass *llvm::createWebAssemblyLateEHPrepare() {
  return new WebAssemblyLateEHPrepare();
}
60
61// Returns the nearest EH pad that dominates this instruction. This does not use
62// dominator analysis; it just does BFS on its predecessors until arriving at an
63// EH pad. This assumes valid EH scopes so the first EH pad it arrives in all
64// possible search paths should be the same.
65// Returns nullptr in case it does not find any EH pad in the search, or finds
66// multiple different EH pads.
67MachineBasicBlock *
68WebAssemblyLateEHPrepare::getMatchingEHPad(MachineInstr *MI) {
69  MachineFunction *MF = MI->getParent()->getParent();
70  SmallVector<MachineBasicBlock *, 2> WL;
71  SmallPtrSet<MachineBasicBlock *, 2> Visited;
72  WL.push_back(MI->getParent());
73  MachineBasicBlock *EHPad = nullptr;
74  while (!WL.empty()) {
75    MachineBasicBlock *MBB = WL.pop_back_val();
76    if (!Visited.insert(MBB).second)
77      continue;
78    if (MBB->isEHPad()) {
79      if (EHPad && EHPad != MBB)
80        return nullptr;
81      EHPad = MBB;
82      continue;
83    }
84    if (MBB == &MF->front())
85      return nullptr;
86    for (auto *Pred : MBB->predecessors())
87      if (!CatchRetBBs.count(Pred)) // We don't go into child scopes
88        WL.push_back(Pred);
89  }
90  return EHPad;
91}
92
93// Erase the specified BBs if the BB does not have any remaining predecessors,
94// and also all its dead children.
95template <typename Container>
96static void eraseDeadBBsAndChildren(const Container &MBBs) {
97  SmallVector<MachineBasicBlock *, 8> WL(MBBs.begin(), MBBs.end());
98  SmallPtrSet<MachineBasicBlock *, 8> Deleted;
99  while (!WL.empty()) {
100    MachineBasicBlock *MBB = WL.pop_back_val();
101    if (Deleted.count(MBB) || !MBB->pred_empty())
102      continue;
103    SmallVector<MachineBasicBlock *, 4> Succs(MBB->successors());
104    WL.append(MBB->succ_begin(), MBB->succ_end());
105    for (auto *Succ : Succs)
106      MBB->removeSuccessor(Succ);
107    // To prevent deleting the same BB multiple times, which can happen when
108    // 'MBBs' contain both a parent and a child
109    Deleted.insert(MBB);
110    MBB->eraseFromParent();
111  }
112}
113
114bool WebAssemblyLateEHPrepare::runOnMachineFunction(MachineFunction &MF) {
115  LLVM_DEBUG(dbgs() << "********** Late EH Prepare **********\n"
116                       "********** Function: "
117                    << MF.getName() << '\n');
118
119  if (MF.getTarget().getMCAsmInfo()->getExceptionHandlingType() !=
120      ExceptionHandling::Wasm)
121    return false;
122
123  bool Changed = false;
124  if (MF.getFunction().hasPersonalityFn()) {
125    Changed |= removeUnreachableEHPads(MF);
126    recordCatchRetBBs(MF);
127    Changed |= hoistCatches(MF);
128    Changed |= addCatchAlls(MF);
129    Changed |= replaceFuncletReturns(MF);
130  }
131  Changed |= removeUnnecessaryUnreachables(MF);
132  if (MF.getFunction().hasPersonalityFn())
133    Changed |= restoreStackPointer(MF);
134  return Changed;
135}
136
137// Remove unreachable EH pads and its children. If they remain, CFG
138// stackification can be tricky.
139bool WebAssemblyLateEHPrepare::removeUnreachableEHPads(MachineFunction &MF) {
140  SmallVector<MachineBasicBlock *, 4> ToDelete;
141  for (auto &MBB : MF)
142    if (MBB.isEHPad() && MBB.pred_empty())
143      ToDelete.push_back(&MBB);
144  eraseDeadBBsAndChildren(ToDelete);
145  return !ToDelete.empty();
146}
147
148// Record which BB ends with catchret instruction, because this will be replaced
149// with 'br's later. This set of catchret BBs is necessary in 'getMatchingEHPad'
150// function.
151void WebAssemblyLateEHPrepare::recordCatchRetBBs(MachineFunction &MF) {
152  CatchRetBBs.clear();
153  for (auto &MBB : MF) {
154    auto Pos = MBB.getFirstTerminator();
155    if (Pos == MBB.end())
156      continue;
157    MachineInstr *TI = &*Pos;
158    if (TI->getOpcode() == WebAssembly::CATCHRET)
159      CatchRetBBs.insert(&MBB);
160  }
161}
162
163// Hoist catch instructions to the beginning of their matching EH pad BBs in
164// case,
165// (1) catch instruction is not the first instruction in EH pad.
166// ehpad:
167//   some_other_instruction
168//   ...
169//   %exn = catch 0
170// (2) catch instruction is in a non-EH pad BB. For example,
171// ehpad:
172//   br bb0
173// bb0:
174//   %exn = catch 0
175bool WebAssemblyLateEHPrepare::hoistCatches(MachineFunction &MF) {
176  bool Changed = false;
177  SmallVector<MachineInstr *, 16> Catches;
178  for (auto &MBB : MF)
179    for (auto &MI : MBB)
180      if (WebAssembly::isCatch(MI.getOpcode()))
181        Catches.push_back(&MI);
182
183  for (auto *Catch : Catches) {
184    MachineBasicBlock *EHPad = getMatchingEHPad(Catch);
185    assert(EHPad && "No matching EH pad for catch");
186    auto InsertPos = EHPad->begin();
187    // Skip EH_LABELs in the beginning of an EH pad if present. We don't use
188    // these labels at the moment, but other targets also seem to have an
189    // EH_LABEL instruction in the beginning of an EH pad.
190    while (InsertPos != EHPad->end() && InsertPos->isEHLabel())
191      InsertPos++;
192    if (InsertPos == Catch)
193      continue;
194    Changed = true;
195    EHPad->insert(InsertPos, Catch->removeFromParent());
196  }
197  return Changed;
198}
199
200// Add catch_all to beginning of cleanup pads.
201bool WebAssemblyLateEHPrepare::addCatchAlls(MachineFunction &MF) {
202  bool Changed = false;
203  const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
204
205  for (auto &MBB : MF) {
206    if (!MBB.isEHPad())
207      continue;
208    auto InsertPos = MBB.begin();
209    // Skip EH_LABELs in the beginning of an EH pad if present.
210    while (InsertPos != MBB.end() && InsertPos->isEHLabel())
211      InsertPos++;
212    // This runs after hoistCatches(), so we assume that if there is a catch,
213    // that should be the first non-EH-label instruction in an EH pad.
214    if (InsertPos == MBB.end() ||
215        !WebAssembly::isCatch(InsertPos->getOpcode())) {
216      Changed = true;
217      BuildMI(MBB, InsertPos,
218              InsertPos == MBB.end() ? DebugLoc() : InsertPos->getDebugLoc(),
219              TII.get(WebAssembly::CATCH_ALL));
220    }
221  }
222  return Changed;
223}
224
225// Replace pseudo-instructions catchret and cleanupret with br and rethrow
226// respectively.
227bool WebAssemblyLateEHPrepare::replaceFuncletReturns(MachineFunction &MF) {
228  bool Changed = false;
229  const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
230
231  for (auto &MBB : MF) {
232    auto Pos = MBB.getFirstTerminator();
233    if (Pos == MBB.end())
234      continue;
235    MachineInstr *TI = &*Pos;
236
237    switch (TI->getOpcode()) {
238    case WebAssembly::CATCHRET: {
239      // Replace a catchret with a branch
240      MachineBasicBlock *TBB = TI->getOperand(0).getMBB();
241      if (!MBB.isLayoutSuccessor(TBB))
242        BuildMI(MBB, TI, TI->getDebugLoc(), TII.get(WebAssembly::BR))
243            .addMBB(TBB);
244      TI->eraseFromParent();
245      Changed = true;
246      break;
247    }
248    case WebAssembly::CLEANUPRET: {
249      // Replace a cleanupret with a rethrow. For C++ support, currently
250      // rethrow's immediate argument is always 0 (= the latest exception).
251      BuildMI(MBB, TI, TI->getDebugLoc(), TII.get(WebAssembly::RETHROW))
252          .addImm(0);
253      TI->eraseFromParent();
254      Changed = true;
255      break;
256    }
257    }
258  }
259  return Changed;
260}
261
262// Remove unnecessary unreachables after a throw or rethrow.
263bool WebAssemblyLateEHPrepare::removeUnnecessaryUnreachables(
264    MachineFunction &MF) {
265  bool Changed = false;
266  for (auto &MBB : MF) {
267    for (auto &MI : MBB) {
268      if (MI.getOpcode() != WebAssembly::THROW &&
269          MI.getOpcode() != WebAssembly::RETHROW)
270        continue;
271      Changed = true;
272
273      // The instruction after the throw should be an unreachable or a branch to
274      // another BB that should eventually lead to an unreachable. Delete it
275      // because throw itself is a terminator, and also delete successors if
276      // any.
277      MBB.erase(std::next(MI.getIterator()), MBB.end());
278      SmallVector<MachineBasicBlock *, 8> Succs(MBB.successors());
279      for (auto *Succ : Succs)
280        if (!Succ->isEHPad())
281          MBB.removeSuccessor(Succ);
282      eraseDeadBBsAndChildren(Succs);
283    }
284  }
285
286  return Changed;
287}
288
289// After the stack is unwound due to a thrown exception, the __stack_pointer
290// global can point to an invalid address. This inserts instructions that
291// restore __stack_pointer global.
292bool WebAssemblyLateEHPrepare::restoreStackPointer(MachineFunction &MF) {
293  const auto *FrameLowering = static_cast<const WebAssemblyFrameLowering *>(
294      MF.getSubtarget().getFrameLowering());
295  if (!FrameLowering->needsPrologForEH(MF))
296    return false;
297  bool Changed = false;
298
299  for (auto &MBB : MF) {
300    if (!MBB.isEHPad())
301      continue;
302    Changed = true;
303
304    // Insert __stack_pointer restoring instructions at the beginning of each EH
305    // pad, after the catch instruction. Here it is safe to assume that SP32
306    // holds the latest value of __stack_pointer, because the only exception for
307    // this case is when a function uses the red zone, but that only happens
308    // with leaf functions, and we don't restore __stack_pointer in leaf
309    // functions anyway.
310    auto InsertPos = MBB.begin();
311    // Skip EH_LABELs in the beginning of an EH pad if present.
312    while (InsertPos != MBB.end() && InsertPos->isEHLabel())
313      InsertPos++;
314    assert(InsertPos != MBB.end() &&
315           WebAssembly::isCatch(InsertPos->getOpcode()) &&
316           "catch/catch_all should be present in every EH pad at this point");
317    ++InsertPos; // Skip the catch instruction
318    FrameLowering->writeSPToGlobal(FrameLowering->getSPReg(MF), MF, MBB,
319                                   InsertPos, MBB.begin()->getDebugLoc());
320  }
321  return Changed;
322}
323