WebAssemblyLateEHPrepare.cpp revision 353358
1//=== WebAssemblyLateEHPrepare.cpp - WebAssembly Exception Preparation -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// \brief Does various transformations for exception handling.
11///
12//===----------------------------------------------------------------------===//
13
14#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
15#include "WebAssembly.h"
16#include "WebAssemblySubtarget.h"
17#include "WebAssemblyUtilities.h"
18#include "llvm/ADT/SmallSet.h"
19#include "llvm/CodeGen/MachineInstrBuilder.h"
20#include "llvm/CodeGen/WasmEHFuncInfo.h"
21#include "llvm/MC/MCAsmInfo.h"
22using namespace llvm;
23
24#define DEBUG_TYPE "wasm-late-eh-prepare"
25
26namespace {
27class WebAssemblyLateEHPrepare final : public MachineFunctionPass {
28  StringRef getPassName() const override {
29    return "WebAssembly Late Prepare Exception";
30  }
31
32  bool runOnMachineFunction(MachineFunction &MF) override;
33  bool addCatches(MachineFunction &MF);
34  bool replaceFuncletReturns(MachineFunction &MF);
35  bool removeUnnecessaryUnreachables(MachineFunction &MF);
36  bool addExceptionExtraction(MachineFunction &MF);
37  bool restoreStackPointer(MachineFunction &MF);
38
39public:
40  static char ID; // Pass identification, replacement for typeid
41  WebAssemblyLateEHPrepare() : MachineFunctionPass(ID) {}
42};
43} // end anonymous namespace
44
45char WebAssemblyLateEHPrepare::ID = 0;
46INITIALIZE_PASS(WebAssemblyLateEHPrepare, DEBUG_TYPE,
47                "WebAssembly Late Exception Preparation", false, false)
48
49FunctionPass *llvm::createWebAssemblyLateEHPrepare() {
50  return new WebAssemblyLateEHPrepare();
51}
52
53// Returns the nearest EH pad that dominates this instruction. This does not use
54// dominator analysis; it just does BFS on its predecessors until arriving at an
55// EH pad. This assumes valid EH scopes so the first EH pad it arrives in all
56// possible search paths should be the same.
57// Returns nullptr in case it does not find any EH pad in the search, or finds
58// multiple different EH pads.
59static MachineBasicBlock *getMatchingEHPad(MachineInstr *MI) {
60  MachineFunction *MF = MI->getParent()->getParent();
61  SmallVector<MachineBasicBlock *, 2> WL;
62  SmallPtrSet<MachineBasicBlock *, 2> Visited;
63  WL.push_back(MI->getParent());
64  MachineBasicBlock *EHPad = nullptr;
65  while (!WL.empty()) {
66    MachineBasicBlock *MBB = WL.pop_back_val();
67    if (Visited.count(MBB))
68      continue;
69    Visited.insert(MBB);
70    if (MBB->isEHPad()) {
71      if (EHPad && EHPad != MBB)
72        return nullptr;
73      EHPad = MBB;
74      continue;
75    }
76    if (MBB == &MF->front())
77      return nullptr;
78    WL.append(MBB->pred_begin(), MBB->pred_end());
79  }
80  return EHPad;
81}
82
83// Erase the specified BBs if the BB does not have any remaining predecessors,
84// and also all its dead children.
85template <typename Container>
86static void eraseDeadBBsAndChildren(const Container &MBBs) {
87  SmallVector<MachineBasicBlock *, 8> WL(MBBs.begin(), MBBs.end());
88  while (!WL.empty()) {
89    MachineBasicBlock *MBB = WL.pop_back_val();
90    if (!MBB->pred_empty())
91      continue;
92    SmallVector<MachineBasicBlock *, 4> Succs(MBB->succ_begin(),
93                                              MBB->succ_end());
94    WL.append(MBB->succ_begin(), MBB->succ_end());
95    for (auto *Succ : Succs)
96      MBB->removeSuccessor(Succ);
97    MBB->eraseFromParent();
98  }
99}
100
101bool WebAssemblyLateEHPrepare::runOnMachineFunction(MachineFunction &MF) {
102  LLVM_DEBUG(dbgs() << "********** Late EH Prepare **********\n"
103                       "********** Function: "
104                    << MF.getName() << '\n');
105
106  if (MF.getTarget().getMCAsmInfo()->getExceptionHandlingType() !=
107      ExceptionHandling::Wasm)
108    return false;
109
110  bool Changed = false;
111  if (MF.getFunction().hasPersonalityFn()) {
112    Changed |= addCatches(MF);
113    Changed |= replaceFuncletReturns(MF);
114  }
115  Changed |= removeUnnecessaryUnreachables(MF);
116  if (MF.getFunction().hasPersonalityFn()) {
117    Changed |= addExceptionExtraction(MF);
118    Changed |= restoreStackPointer(MF);
119  }
120  return Changed;
121}
122
123// Add catch instruction to beginning of catchpads and cleanuppads.
124bool WebAssemblyLateEHPrepare::addCatches(MachineFunction &MF) {
125  bool Changed = false;
126  const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
127  MachineRegisterInfo &MRI = MF.getRegInfo();
128  for (auto &MBB : MF) {
129    if (MBB.isEHPad()) {
130      Changed = true;
131      auto InsertPos = MBB.begin();
132      if (InsertPos->isEHLabel()) // EH pad starts with an EH label
133        ++InsertPos;
134      unsigned DstReg = MRI.createVirtualRegister(&WebAssembly::EXNREFRegClass);
135      BuildMI(MBB, InsertPos, MBB.begin()->getDebugLoc(),
136              TII.get(WebAssembly::CATCH), DstReg);
137    }
138  }
139  return Changed;
140}
141
142bool WebAssemblyLateEHPrepare::replaceFuncletReturns(MachineFunction &MF) {
143  bool Changed = false;
144  const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
145
146  for (auto &MBB : MF) {
147    auto Pos = MBB.getFirstTerminator();
148    if (Pos == MBB.end())
149      continue;
150    MachineInstr *TI = &*Pos;
151
152    switch (TI->getOpcode()) {
153    case WebAssembly::CATCHRET: {
154      // Replace a catchret with a branch
155      MachineBasicBlock *TBB = TI->getOperand(0).getMBB();
156      if (!MBB.isLayoutSuccessor(TBB))
157        BuildMI(MBB, TI, TI->getDebugLoc(), TII.get(WebAssembly::BR))
158            .addMBB(TBB);
159      TI->eraseFromParent();
160      Changed = true;
161      break;
162    }
163    case WebAssembly::CLEANUPRET:
164    case WebAssembly::RETHROW_IN_CATCH: {
165      // Replace a cleanupret/rethrow_in_catch with a rethrow
166      auto *EHPad = getMatchingEHPad(TI);
167      auto CatchPos = EHPad->begin();
168      if (CatchPos->isEHLabel()) // EH pad starts with an EH label
169        ++CatchPos;
170      MachineInstr *Catch = &*CatchPos;
171      unsigned ExnReg = Catch->getOperand(0).getReg();
172      BuildMI(MBB, TI, TI->getDebugLoc(), TII.get(WebAssembly::RETHROW))
173          .addReg(ExnReg);
174      TI->eraseFromParent();
175      Changed = true;
176      break;
177    }
178    }
179  }
180  return Changed;
181}
182
183bool WebAssemblyLateEHPrepare::removeUnnecessaryUnreachables(
184    MachineFunction &MF) {
185  bool Changed = false;
186  for (auto &MBB : MF) {
187    for (auto &MI : MBB) {
188      if (MI.getOpcode() != WebAssembly::THROW &&
189          MI.getOpcode() != WebAssembly::RETHROW)
190        continue;
191      Changed = true;
192
193      // The instruction after the throw should be an unreachable or a branch to
194      // another BB that should eventually lead to an unreachable. Delete it
195      // because throw itself is a terminator, and also delete successors if
196      // any.
197      MBB.erase(std::next(MI.getIterator()), MBB.end());
198      SmallVector<MachineBasicBlock *, 8> Succs(MBB.succ_begin(),
199                                                MBB.succ_end());
200      for (auto *Succ : Succs)
201        if (!Succ->isEHPad())
202          MBB.removeSuccessor(Succ);
203      eraseDeadBBsAndChildren(Succs);
204    }
205  }
206
207  return Changed;
208}
209
210// Wasm uses 'br_on_exn' instruction to check the tag of an exception. It takes
211// exnref type object returned by 'catch', and branches to the destination if it
212// matches a given tag. We currently use __cpp_exception symbol to represent the
213// tag for all C++ exceptions.
214//
215// block $l (result i32)
216//   ...
217//   ;; exnref $e is on the stack at this point
218//   br_on_exn $l $e ;; branch to $l with $e's arguments
219//   ...
220// end
221// ;; Here we expect the extracted values are on top of the wasm value stack
222// ... Handle exception using values ...
223//
224// br_on_exn takes an exnref object and branches if it matches the given tag.
225// There can be multiple br_on_exn instructions if we want to match for another
226// tag, but for now we only test for __cpp_exception tag, and if it does not
227// match, i.e., it is a foreign exception, we rethrow it.
228//
229// In the destination BB that's the target of br_on_exn, extracted exception
230// values (in C++'s case a single i32, which represents an exception pointer)
231// are placed on top of the wasm stack. Because we can't model wasm stack in
232// LLVM instruction, we use 'extract_exception' pseudo instruction to retrieve
233// it. The pseudo instruction will be deleted later.
234bool WebAssemblyLateEHPrepare::addExceptionExtraction(MachineFunction &MF) {
235  const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
236  auto *EHInfo = MF.getWasmEHFuncInfo();
237  SmallVector<MachineInstr *, 16> ExtractInstrs;
238  SmallVector<MachineInstr *, 8> ToDelete;
239  for (auto &MBB : MF) {
240    for (auto &MI : MBB) {
241      if (MI.getOpcode() == WebAssembly::EXTRACT_EXCEPTION_I32) {
242        if (MI.getOperand(0).isDead())
243          ToDelete.push_back(&MI);
244        else
245          ExtractInstrs.push_back(&MI);
246      }
247    }
248  }
249  bool Changed = !ToDelete.empty() || !ExtractInstrs.empty();
250  for (auto *MI : ToDelete)
251    MI->eraseFromParent();
252  if (ExtractInstrs.empty())
253    return Changed;
254
255  // Find terminate pads.
256  SmallSet<MachineBasicBlock *, 8> TerminatePads;
257  for (auto &MBB : MF) {
258    for (auto &MI : MBB) {
259      if (MI.isCall()) {
260        const MachineOperand &CalleeOp = MI.getOperand(0);
261        if (CalleeOp.isGlobal() && CalleeOp.getGlobal()->getName() ==
262                                       WebAssembly::ClangCallTerminateFn)
263          TerminatePads.insert(getMatchingEHPad(&MI));
264      }
265    }
266  }
267
268  for (auto *Extract : ExtractInstrs) {
269    MachineBasicBlock *EHPad = getMatchingEHPad(Extract);
270    assert(EHPad && "No matching EH pad for extract_exception");
271    auto CatchPos = EHPad->begin();
272    if (CatchPos->isEHLabel()) // EH pad starts with an EH label
273      ++CatchPos;
274    MachineInstr *Catch = &*CatchPos;
275
276    if (Catch->getNextNode() != Extract)
277      EHPad->insert(Catch->getNextNode(), Extract->removeFromParent());
278
279    // - Before:
280    // ehpad:
281    //   %exnref:exnref = catch
282    //   %exn:i32 = extract_exception
283    //   ... use exn ...
284    //
285    // - After:
286    // ehpad:
287    //   %exnref:exnref = catch
288    //   br_on_exn %thenbb, $__cpp_exception, %exnref
289    //   br %elsebb
290    // elsebb:
291    //   rethrow
292    // thenbb:
293    //   %exn:i32 = extract_exception
294    //   ... use exn ...
295    unsigned ExnReg = Catch->getOperand(0).getReg();
296    auto *ThenMBB = MF.CreateMachineBasicBlock();
297    auto *ElseMBB = MF.CreateMachineBasicBlock();
298    MF.insert(std::next(MachineFunction::iterator(EHPad)), ElseMBB);
299    MF.insert(std::next(MachineFunction::iterator(ElseMBB)), ThenMBB);
300    ThenMBB->splice(ThenMBB->end(), EHPad, Extract, EHPad->end());
301    ThenMBB->transferSuccessors(EHPad);
302    EHPad->addSuccessor(ThenMBB);
303    EHPad->addSuccessor(ElseMBB);
304
305    DebugLoc DL = Extract->getDebugLoc();
306    const char *CPPExnSymbol = MF.createExternalSymbolName("__cpp_exception");
307    BuildMI(EHPad, DL, TII.get(WebAssembly::BR_ON_EXN))
308        .addMBB(ThenMBB)
309        .addExternalSymbol(CPPExnSymbol)
310        .addReg(ExnReg);
311    BuildMI(EHPad, DL, TII.get(WebAssembly::BR)).addMBB(ElseMBB);
312
313    // When this is a terminate pad with __clang_call_terminate() call, we don't
314    // rethrow it anymore and call __clang_call_terminate() with a nullptr
315    // argument, which will call std::terminate().
316    //
317    // - Before:
318    // ehpad:
319    //   %exnref:exnref = catch
320    //   %exn:i32 = extract_exception
321    //   call @__clang_call_terminate(%exn)
322    //   unreachable
323    //
324    // - After:
325    // ehpad:
326    //   %exnref:exnref = catch
327    //   br_on_exn %thenbb, $__cpp_exception, %exnref
328    //   br %elsebb
329    // elsebb:
330    //   call @__clang_call_terminate(0)
331    //   unreachable
332    // thenbb:
333    //   %exn:i32 = extract_exception
334    //   call @__clang_call_terminate(%exn)
335    //   unreachable
336    if (TerminatePads.count(EHPad)) {
337      Function *ClangCallTerminateFn =
338          MF.getFunction().getParent()->getFunction(
339              WebAssembly::ClangCallTerminateFn);
340      assert(ClangCallTerminateFn &&
341             "There is no __clang_call_terminate() function");
342      BuildMI(ElseMBB, DL, TII.get(WebAssembly::CALL_VOID))
343          .addGlobalAddress(ClangCallTerminateFn)
344          .addImm(0);
345      BuildMI(ElseMBB, DL, TII.get(WebAssembly::UNREACHABLE));
346
347    } else {
348      BuildMI(ElseMBB, DL, TII.get(WebAssembly::RETHROW)).addReg(ExnReg);
349      if (EHInfo->hasEHPadUnwindDest(EHPad))
350        ElseMBB->addSuccessor(EHInfo->getEHPadUnwindDest(EHPad));
351    }
352  }
353
354  return true;
355}
356
357// After the stack is unwound due to a thrown exception, the __stack_pointer
358// global can point to an invalid address. This inserts instructions that
359// restore __stack_pointer global.
360bool WebAssemblyLateEHPrepare::restoreStackPointer(MachineFunction &MF) {
361  const auto *FrameLowering = static_cast<const WebAssemblyFrameLowering *>(
362      MF.getSubtarget().getFrameLowering());
363  if (!FrameLowering->needsPrologForEH(MF))
364    return false;
365  bool Changed = false;
366
367  for (auto &MBB : MF) {
368    if (!MBB.isEHPad())
369      continue;
370    Changed = true;
371
372    // Insert __stack_pointer restoring instructions at the beginning of each EH
373    // pad, after the catch instruction. Here it is safe to assume that SP32
374    // holds the latest value of __stack_pointer, because the only exception for
375    // this case is when a function uses the red zone, but that only happens
376    // with leaf functions, and we don't restore __stack_pointer in leaf
377    // functions anyway.
378    auto InsertPos = MBB.begin();
379    if (InsertPos->isEHLabel()) // EH pad starts with an EH label
380      ++InsertPos;
381    if (InsertPos->getOpcode() == WebAssembly::CATCH)
382      ++InsertPos;
383    FrameLowering->writeSPToGlobal(WebAssembly::SP32, MF, MBB, InsertPos,
384                                   MBB.begin()->getDebugLoc());
385  }
386  return Changed;
387}
388