LowerExpectIntrinsic.cpp revision 263508
121308Sache//===- LowerExpectIntrinsic.cpp - Lower expect intrinsic ------------------===//
221308Sache//
321308Sache//                     The LLVM Compiler Infrastructure
421308Sache//
521308Sache// This file is distributed under the University of Illinois Open Source
621308Sache// License. See LICENSE.TXT for details.
721308Sache//
821308Sache//===----------------------------------------------------------------------===//
921308Sache//
1058310Sache// This pass lowers the 'expect' intrinsic to LLVM metadata.
1121308Sache//
1221308Sache//===----------------------------------------------------------------------===//
1321308Sache
1421308Sache#define DEBUG_TYPE "lower-expect-intrinsic"
1521308Sache#include "llvm/Transforms/Scalar.h"
1621308Sache#include "llvm/ADT/Statistic.h"
1721308Sache#include "llvm/IR/BasicBlock.h"
1821308Sache#include "llvm/IR/Constants.h"
1921308Sache#include "llvm/IR/Function.h"
2021308Sache#include "llvm/IR/Instructions.h"
2158310Sache#include "llvm/IR/Intrinsics.h"
2221308Sache#include "llvm/IR/LLVMContext.h"
2321308Sache#include "llvm/IR/MDBuilder.h"
24136644Sache#include "llvm/IR/Metadata.h"
25136644Sache#include "llvm/Pass.h"
26136644Sache#include "llvm/Support/CommandLine.h"
27136644Sache#include "llvm/Support/Debug.h"
2821308Sache#include <vector>
2921308Sache
3021308Sacheusing namespace llvm;
3121308Sache
3221308SacheSTATISTIC(IfHandled, "Number of 'expect' intrinsic instructions handled");
3321308Sache
3421308Sachestatic cl::opt<uint32_t>
3521308SacheLikelyBranchWeight("likely-branch-weight", cl::Hidden, cl::init(64),
3621308Sache                   cl::desc("Weight of the branch likely to be taken (default = 64)"));
37119610Sachestatic cl::opt<uint32_t>
38119610SacheUnlikelyBranchWeight("unlikely-branch-weight", cl::Hidden, cl::init(4),
39119610Sache                   cl::desc("Weight of the branch unlikely to be taken (default = 4)"));
40119610Sache
4121308Sachenamespace {
4221308Sache
4321308Sache  class LowerExpectIntrinsic : public FunctionPass {
4421308Sache
4521308Sache    bool HandleSwitchExpect(SwitchInst *SI);
4621308Sache
4721308Sache    bool HandleIfExpect(BranchInst *BI);
4821308Sache
4921308Sache  public:
5021308Sache    static char ID;
5121308Sache    LowerExpectIntrinsic() : FunctionPass(ID) {
5221308Sache      initializeLowerExpectIntrinsicPass(*PassRegistry::getPassRegistry());
5321308Sache    }
5421308Sache
5521308Sache    bool runOnFunction(Function &F);
5621308Sache  };
5721308Sache}
5821308Sache
5921308Sache
6021308Sachebool LowerExpectIntrinsic::HandleSwitchExpect(SwitchInst *SI) {
6121308Sache  CallInst *CI = dyn_cast<CallInst>(SI->getCondition());
6221308Sache  if (!CI)
6358310Sache    return false;
6421308Sache
65119610Sache  Function *Fn = CI->getCalledFunction();
6621308Sache  if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect)
6721308Sache    return false;
6821308Sache
6921308Sache  Value *ArgValue = CI->getArgOperand(0);
7021308Sache  ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
7121308Sache  if (!ExpectedValue)
7221308Sache    return false;
7321308Sache
7421308Sache  SwitchInst::CaseIt Case = SI->findCaseValue(ExpectedValue);
7575406Sache  unsigned n = SI->getNumCases(); // +1 for default case.
7675406Sache  std::vector<uint32_t> Weights(n + 1);
7758310Sache
7858310Sache  Weights[0] = Case == SI->case_default() ? LikelyBranchWeight
7958310Sache                                          : UnlikelyBranchWeight;
8058310Sache  for (unsigned i = 0; i != n; ++i)
8158310Sache    Weights[i + 1] = i == Case.getCaseIndex() ? LikelyBranchWeight
8258310Sache                                              : UnlikelyBranchWeight;
8358310Sache
8458310Sache  SI->setMetadata(LLVMContext::MD_prof,
8558310Sache                  MDBuilder(CI->getContext()).createBranchWeights(Weights));
8658310Sache
8758310Sache  SI->setCondition(ArgValue);
8858310Sache  return true;
8958310Sache}
9058310Sache
9158310Sache
9258310Sachebool LowerExpectIntrinsic::HandleIfExpect(BranchInst *BI) {
9358310Sache  if (BI->isUnconditional())
9458310Sache    return false;
9558310Sache
9621308Sache  // Handle non-optimized IR code like:
9721308Sache  //   %expval = call i64 @llvm.expect.i64.i64(i64 %conv1, i64 1)
9875406Sache  //   %tobool = icmp ne i64 %expval, 0
9975406Sache  //   br i1 %tobool, label %if.then, label %if.end
10075406Sache
10175406Sache  ICmpInst *CmpI = dyn_cast<ICmpInst>(BI->getCondition());
10275406Sache  if (!CmpI || CmpI->getPredicate() != CmpInst::ICMP_NE)
10375406Sache    return false;
10475406Sache
10575406Sache  CallInst *CI = dyn_cast<CallInst>(CmpI->getOperand(0));
10675406Sache  if (!CI)
10775406Sache    return false;
10875406Sache
10975406Sache  Function *Fn = CI->getCalledFunction();
11021308Sache  if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect)
11121308Sache    return false;
11221308Sache
11321308Sache  Value *ArgValue = CI->getArgOperand(0);
114119610Sache  ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
11521308Sache  if (!ExpectedValue)
11621308Sache    return false;
11721308Sache
11821308Sache  MDBuilder MDB(CI->getContext());
11921308Sache  MDNode *Node;
12021308Sache
12121308Sache  // If expect value is equal to 1 it means that we are more likely to take
122119610Sache  // branch 0, in other case more likely is branch 1.
12321308Sache  if (ExpectedValue->isOne())
12421308Sache    Node = MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight);
12521308Sache  else
12621308Sache    Node = MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight);
12721308Sache
12821308Sache  BI->setMetadata(LLVMContext::MD_prof, Node);
12921308Sache
13021308Sache  CmpI->setOperand(0, ArgValue);
13121308Sache  return true;
13221308Sache}
13335486Sache
13475406Sache
13521308Sachebool LowerExpectIntrinsic::runOnFunction(Function &F) {
13621308Sache  for (Function::iterator I = F.begin(), E = F.end(); I != E;) {
13721308Sache    BasicBlock *BB = I++;
13821308Sache
13921308Sache    // Create "block_weights" metadata.
14021308Sache    if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) {
14121308Sache      if (HandleIfExpect(BI))
142119610Sache        IfHandled++;
14321308Sache    } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) {
14421308Sache      if (HandleSwitchExpect(SI))
14521308Sache        IfHandled++;
14621308Sache    }
14721308Sache
14821308Sache    // remove llvm.expect intrinsics.
14921308Sache    for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
15021308Sache         BI != BE; ) {
15121308Sache      CallInst *CI = dyn_cast<CallInst>(BI++);
15221308Sache      if (!CI)
15321308Sache        continue;
15421308Sache
15521308Sache      Function *Fn = CI->getCalledFunction();
15621308Sache      if (Fn && Fn->getIntrinsicID() == Intrinsic::expect) {
15721308Sache        Value *Exp = CI->getArgOperand(0);
15821308Sache        CI->replaceAllUsesWith(Exp);
15921308Sache        CI->eraseFromParent();
16021308Sache      }
16121308Sache    }
16221308Sache  }
16321308Sache
16421308Sache  return false;
16521308Sache}
16621308Sache
16721308Sache
16821308Sachechar LowerExpectIntrinsic::ID = 0;
16921308SacheINITIALIZE_PASS(LowerExpectIntrinsic, "lower-expect", "Lower 'expect' "
17021308Sache                "Intrinsics", false, false)
17121308Sache
17221308SacheFunctionPass *llvm::createLowerExpectIntrinsicPass() {
17321308Sache  return new LowerExpectIntrinsic();
17421308Sache}
17521308Sache