1//===- UnifyFunctionExitNodes.cpp - Make all functions have a single exit -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This pass ensures that each function has at most one return instruction and
// at most one unreachable instruction, by redirecting every returning (resp.
// unreachable-terminated) block to a single unified block.  It also records
// the resulting unique return block and unreachable block; either one is null
// when the function has no block of that kind.
13//
14//===----------------------------------------------------------------------===//
15
16#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
17#include "llvm/IR/BasicBlock.h"
18#include "llvm/IR/Function.h"
19#include "llvm/IR/Instructions.h"
20#include "llvm/IR/Type.h"
21#include "llvm/InitializePasses.h"
22#include "llvm/Transforms/Utils.h"
23using namespace llvm;
24
25char UnifyFunctionExitNodes::ID = 0;
26
27UnifyFunctionExitNodes::UnifyFunctionExitNodes() : FunctionPass(ID) {
28  initializeUnifyFunctionExitNodesPass(*PassRegistry::getPassRegistry());
29}
30
31INITIALIZE_PASS(UnifyFunctionExitNodes, "mergereturn",
32                "Unify function exit nodes", false, false)
33
34Pass *llvm::createUnifyFunctionExitNodesPass() {
35  return new UnifyFunctionExitNodes();
36}
37
38void UnifyFunctionExitNodes::getAnalysisUsage(AnalysisUsage &AU) const{
39  // We preserve the non-critical-edgeness property
40  AU.addPreservedID(BreakCriticalEdgesID);
41  // This is a cluster of orthogonal Transforms
42  AU.addPreservedID(LowerSwitchID);
43}
44
45// UnifyAllExitNodes - Unify all exit nodes of the CFG by creating a new
46// BasicBlock, and converting all returns to unconditional branches to this
47// new basic block.  The singular exit node is returned.
48//
49// If there are no return stmts in the Function, a null pointer is returned.
50//
51bool UnifyFunctionExitNodes::runOnFunction(Function &F) {
52  // Loop over all of the blocks in a function, tracking all of the blocks that
53  // return.
54  //
55  std::vector<BasicBlock*> ReturningBlocks;
56  std::vector<BasicBlock*> UnreachableBlocks;
57  for (BasicBlock &I : F)
58    if (isa<ReturnInst>(I.getTerminator()))
59      ReturningBlocks.push_back(&I);
60    else if (isa<UnreachableInst>(I.getTerminator()))
61      UnreachableBlocks.push_back(&I);
62
63  // Then unreachable blocks.
64  if (UnreachableBlocks.empty()) {
65    UnreachableBlock = nullptr;
66  } else if (UnreachableBlocks.size() == 1) {
67    UnreachableBlock = UnreachableBlocks.front();
68  } else {
69    UnreachableBlock = BasicBlock::Create(F.getContext(),
70                                          "UnifiedUnreachableBlock", &F);
71    new UnreachableInst(F.getContext(), UnreachableBlock);
72
73    for (BasicBlock *BB : UnreachableBlocks) {
74      BB->getInstList().pop_back();  // Remove the unreachable inst.
75      BranchInst::Create(UnreachableBlock, BB);
76    }
77  }
78
79  // Now handle return blocks.
80  if (ReturningBlocks.empty()) {
81    ReturnBlock = nullptr;
82    return false;                          // No blocks return
83  } else if (ReturningBlocks.size() == 1) {
84    ReturnBlock = ReturningBlocks.front(); // Already has a single return block
85    return false;
86  }
87
88  // Otherwise, we need to insert a new basic block into the function, add a PHI
89  // nodes (if the function returns values), and convert all of the return
90  // instructions into unconditional branches.
91  //
92  BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(),
93                                               "UnifiedReturnBlock", &F);
94
95  PHINode *PN = nullptr;
96  if (F.getReturnType()->isVoidTy()) {
97    ReturnInst::Create(F.getContext(), nullptr, NewRetBlock);
98  } else {
99    // If the function doesn't return void... add a PHI node to the block...
100    PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(),
101                         "UnifiedRetVal");
102    NewRetBlock->getInstList().push_back(PN);
103    ReturnInst::Create(F.getContext(), PN, NewRetBlock);
104  }
105
106  // Loop over all of the blocks, replacing the return instruction with an
107  // unconditional branch.
108  //
109  for (BasicBlock *BB : ReturningBlocks) {
110    // Add an incoming element to the PHI node for every return instruction that
111    // is merging into this new block...
112    if (PN)
113      PN->addIncoming(BB->getTerminator()->getOperand(0), BB);
114
115    BB->getInstList().pop_back();  // Remove the return insn
116    BranchInst::Create(NewRetBlock, BB);
117  }
118  ReturnBlock = NewRetBlock;
119  return true;
120}
121