1193323Sed//===- UnifyFunctionExitNodes.cpp - Make all functions have a single exit -===//
2193323Sed//
3193323Sed//                     The LLVM Compiler Infrastructure
4193323Sed//
5193323Sed// This file is distributed under the University of Illinois Open Source
6193323Sed// License. See LICENSE.TXT for details.
7193323Sed//
8193323Sed//===----------------------------------------------------------------------===//
9193323Sed//
// This pass is used to ensure that functions have at most one return
// instruction and at most one unreachable instruction in them.  Additionally,
// it keeps track of which blocks are the new return and unreachable exit nodes
// of the CFG.  If the CFG has no exit node of a given kind, the block recorded
// for that kind is a null pointer.
14193323Sed//
15193323Sed//===----------------------------------------------------------------------===//
16193323Sed
17193323Sed#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
18249423Sdim#include "llvm/ADT/StringExtras.h"
19249423Sdim#include "llvm/IR/BasicBlock.h"
20249423Sdim#include "llvm/IR/Function.h"
21249423Sdim#include "llvm/IR/Instructions.h"
22249423Sdim#include "llvm/IR/Type.h"
23193323Sed#include "llvm/Transforms/Scalar.h"
24193323Sedusing namespace llvm;
25193323Sed
26193323Sedchar UnifyFunctionExitNodes::ID = 0;
27212904SdimINITIALIZE_PASS(UnifyFunctionExitNodes, "mergereturn",
28218893Sdim                "Unify function exit nodes", false, false)
29193323Sed
30193323SedPass *llvm::createUnifyFunctionExitNodesPass() {
31193323Sed  return new UnifyFunctionExitNodes();
32193323Sed}
33193323Sed
34193323Sedvoid UnifyFunctionExitNodes::getAnalysisUsage(AnalysisUsage &AU) const{
35193323Sed  // We preserve the non-critical-edgeness property
36193323Sed  AU.addPreservedID(BreakCriticalEdgesID);
37193323Sed  // This is a cluster of orthogonal Transforms
38212904Sdim  AU.addPreserved("mem2reg");
39193323Sed  AU.addPreservedID(LowerSwitchID);
40193323Sed}
41193323Sed
42193323Sed// UnifyAllExitNodes - Unify all exit nodes of the CFG by creating a new
43193323Sed// BasicBlock, and converting all returns to unconditional branches to this
44193323Sed// new basic block.  The singular exit node is returned.
45193323Sed//
46193323Sed// If there are no return stmts in the Function, a null pointer is returned.
47193323Sed//
48193323Sedbool UnifyFunctionExitNodes::runOnFunction(Function &F) {
49193323Sed  // Loop over all of the blocks in a function, tracking all of the blocks that
50193323Sed  // return.
51193323Sed  //
52193323Sed  std::vector<BasicBlock*> ReturningBlocks;
53193323Sed  std::vector<BasicBlock*> UnreachableBlocks;
54193323Sed  for(Function::iterator I = F.begin(), E = F.end(); I != E; ++I)
55193323Sed    if (isa<ReturnInst>(I->getTerminator()))
56193323Sed      ReturningBlocks.push_back(I);
57193323Sed    else if (isa<UnreachableInst>(I->getTerminator()))
58193323Sed      UnreachableBlocks.push_back(I);
59193323Sed
60193323Sed  // Then unreachable blocks.
61193323Sed  if (UnreachableBlocks.empty()) {
62193323Sed    UnreachableBlock = 0;
63193323Sed  } else if (UnreachableBlocks.size() == 1) {
64193323Sed    UnreachableBlock = UnreachableBlocks.front();
65193323Sed  } else {
66198090Srdivacky    UnreachableBlock = BasicBlock::Create(F.getContext(),
67198090Srdivacky                                          "UnifiedUnreachableBlock", &F);
68198090Srdivacky    new UnreachableInst(F.getContext(), UnreachableBlock);
69193323Sed
70193323Sed    for (std::vector<BasicBlock*>::iterator I = UnreachableBlocks.begin(),
71193323Sed           E = UnreachableBlocks.end(); I != E; ++I) {
72193323Sed      BasicBlock *BB = *I;
73193323Sed      BB->getInstList().pop_back();  // Remove the unreachable inst.
74193323Sed      BranchInst::Create(UnreachableBlock, BB);
75193323Sed    }
76193323Sed  }
77193323Sed
78193323Sed  // Now handle return blocks.
79193323Sed  if (ReturningBlocks.empty()) {
80193323Sed    ReturnBlock = 0;
81193323Sed    return false;                          // No blocks return
82193323Sed  } else if (ReturningBlocks.size() == 1) {
83193323Sed    ReturnBlock = ReturningBlocks.front(); // Already has a single return block
84193323Sed    return false;
85193323Sed  }
86193323Sed
87193323Sed  // Otherwise, we need to insert a new basic block into the function, add a PHI
88193323Sed  // nodes (if the function returns values), and convert all of the return
89193323Sed  // instructions into unconditional branches.
90193323Sed  //
91198090Srdivacky  BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(),
92198090Srdivacky                                               "UnifiedReturnBlock", &F);
93193323Sed
94193323Sed  PHINode *PN = 0;
95202375Srdivacky  if (F.getReturnType()->isVoidTy()) {
96198090Srdivacky    ReturnInst::Create(F.getContext(), NULL, NewRetBlock);
97193323Sed  } else {
98193323Sed    // If the function doesn't return void... add a PHI node to the block...
99221345Sdim    PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(),
100221345Sdim                         "UnifiedRetVal");
101193323Sed    NewRetBlock->getInstList().push_back(PN);
102198090Srdivacky    ReturnInst::Create(F.getContext(), PN, NewRetBlock);
103193323Sed  }
104193323Sed
105193323Sed  // Loop over all of the blocks, replacing the return instruction with an
106193323Sed  // unconditional branch.
107193323Sed  //
108193323Sed  for (std::vector<BasicBlock*>::iterator I = ReturningBlocks.begin(),
109193323Sed         E = ReturningBlocks.end(); I != E; ++I) {
110193323Sed    BasicBlock *BB = *I;
111193323Sed
112193323Sed    // Add an incoming element to the PHI node for every return instruction that
113193323Sed    // is merging into this new block...
114193323Sed    if (PN)
115193323Sed      PN->addIncoming(BB->getTerminator()->getOperand(0), BB);
116193323Sed
117193323Sed    BB->getInstList().pop_back();  // Remove the return insn
118193323Sed    BranchInst::Create(NewRetBlock, BB);
119193323Sed  }
120193323Sed  ReturnBlock = NewRetBlock;
121193323Sed  return true;
122193323Sed}
123