CostModel.cpp revision 256281
1//===- CostModel.cpp ------ Cost Model Analysis ---------------------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This file defines the cost model analysis. It provides a very basic cost
11// estimation for LLVM-IR. This analysis uses the services of the codegen
12// to approximate the cost of any IR instruction when lowered to machine
13// instructions. The cost results are unit-less and the cost number represents
14// the throughput of the machine assuming that all loads hit the cache, all
15// branches are predicted, etc. The cost numbers can be added in order to
16// compare two or more transformation alternatives.
17//
18//===----------------------------------------------------------------------===//
19
20#define CM_NAME "cost-model"
21#define DEBUG_TYPE CM_NAME
22#include "llvm/Analysis/Passes.h"
23#include "llvm/Analysis/TargetTransformInfo.h"
24#include "llvm/IR/Function.h"
25#include "llvm/IR/Instructions.h"
26#include "llvm/IR/IntrinsicInst.h"
27#include "llvm/IR/Value.h"
28#include "llvm/Pass.h"
29#include "llvm/Support/Debug.h"
30#include "llvm/Support/raw_ostream.h"
31using namespace llvm;
32
33namespace {
34  class CostModelAnalysis : public FunctionPass {
35
36  public:
37    static char ID; // Class identification, replacement for typeinfo
38    CostModelAnalysis() : FunctionPass(ID), F(0), TTI(0) {
39      initializeCostModelAnalysisPass(
40        *PassRegistry::getPassRegistry());
41    }
42
43    /// Returns the expected cost of the instruction.
44    /// Returns -1 if the cost is unknown.
45    /// Note, this method does not cache the cost calculation and it
46    /// can be expensive in some cases.
47    unsigned getInstructionCost(const Instruction *I) const;
48
49  private:
50    virtual void getAnalysisUsage(AnalysisUsage &AU) const;
51    virtual bool runOnFunction(Function &F);
52    virtual void print(raw_ostream &OS, const Module*) const;
53
54    /// The function that we analyze.
55    Function *F;
56    /// Target information.
57    const TargetTransformInfo *TTI;
58  };
59}  // End of anonymous namespace
60
61// Register this pass.
62char CostModelAnalysis::ID = 0;
63static const char cm_name[] = "Cost Model Analysis";
64INITIALIZE_PASS_BEGIN(CostModelAnalysis, CM_NAME, cm_name, false, true)
65INITIALIZE_PASS_END  (CostModelAnalysis, CM_NAME, cm_name, false, true)
66
67FunctionPass *llvm::createCostModelAnalysisPass() {
68  return new CostModelAnalysis();
69}
70
71void
72CostModelAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
73  AU.setPreservesAll();
74}
75
76bool
77CostModelAnalysis::runOnFunction(Function &F) {
78 this->F = &F;
79 TTI = getAnalysisIfAvailable<TargetTransformInfo>();
80
81 return false;
82}
83
/// Returns true if \p Mask selects lanes in reverse order, i.e. every defined
/// entry at position i equals Mask.size() - 1 - i.
///
/// Negative entries are treated as wildcards (ShuffleVectorInst encodes undef
/// mask lanes as -1), so an empty or all-undef mask is accepted.
///
/// \p MaskT is any indexable container of int providing size(), such as the
/// SmallVector<int, 16> produced by ShuffleVectorInst::getShuffleMask().
template <typename MaskT>
static bool isReverseVectorMask(const MaskT &Mask) {
  for (unsigned i = 0, MaskSize = Mask.size(); i < MaskSize; ++i)
    // Index 0 is a real lane (the expected value in the last position of a
    // reversal), so only strictly negative entries may be skipped.
    if (Mask[i] >= 0 && Mask[i] != static_cast<int>(MaskSize - 1 - i))
      return false;
  return true;
}
90
91static TargetTransformInfo::OperandValueKind getOperandInfo(Value *V) {
92  TargetTransformInfo::OperandValueKind OpInfo =
93    TargetTransformInfo::OK_AnyValue;
94
95  // Check for a splat of a constant.
96  ConstantDataVector *CDV = 0;
97  if ((CDV = dyn_cast<ConstantDataVector>(V)))
98    if (CDV->getSplatValue() != NULL)
99      OpInfo = TargetTransformInfo::OK_UniformConstantValue;
100  ConstantVector *CV = 0;
101  if ((CV = dyn_cast<ConstantVector>(V)))
102    if (CV->getSplatValue() != NULL)
103      OpInfo = TargetTransformInfo::OK_UniformConstantValue;
104
105  return OpInfo;
106}
107
108unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const {
109  if (!TTI)
110    return -1;
111
112  switch (I->getOpcode()) {
113  case Instruction::GetElementPtr:{
114    Type *ValTy = I->getOperand(0)->getType()->getPointerElementType();
115    return TTI->getAddressComputationCost(ValTy);
116  }
117
118  case Instruction::Ret:
119  case Instruction::PHI:
120  case Instruction::Br: {
121    return TTI->getCFInstrCost(I->getOpcode());
122  }
123  case Instruction::Add:
124  case Instruction::FAdd:
125  case Instruction::Sub:
126  case Instruction::FSub:
127  case Instruction::Mul:
128  case Instruction::FMul:
129  case Instruction::UDiv:
130  case Instruction::SDiv:
131  case Instruction::FDiv:
132  case Instruction::URem:
133  case Instruction::SRem:
134  case Instruction::FRem:
135  case Instruction::Shl:
136  case Instruction::LShr:
137  case Instruction::AShr:
138  case Instruction::And:
139  case Instruction::Or:
140  case Instruction::Xor: {
141    TargetTransformInfo::OperandValueKind Op1VK =
142      getOperandInfo(I->getOperand(0));
143    TargetTransformInfo::OperandValueKind Op2VK =
144      getOperandInfo(I->getOperand(1));
145    return TTI->getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK,
146                                       Op2VK);
147  }
148  case Instruction::Select: {
149    const SelectInst *SI = cast<SelectInst>(I);
150    Type *CondTy = SI->getCondition()->getType();
151    return TTI->getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy);
152  }
153  case Instruction::ICmp:
154  case Instruction::FCmp: {
155    Type *ValTy = I->getOperand(0)->getType();
156    return TTI->getCmpSelInstrCost(I->getOpcode(), ValTy);
157  }
158  case Instruction::Store: {
159    const StoreInst *SI = cast<StoreInst>(I);
160    Type *ValTy = SI->getValueOperand()->getType();
161    return TTI->getMemoryOpCost(I->getOpcode(), ValTy,
162                                 SI->getAlignment(),
163                                 SI->getPointerAddressSpace());
164  }
165  case Instruction::Load: {
166    const LoadInst *LI = cast<LoadInst>(I);
167    return TTI->getMemoryOpCost(I->getOpcode(), I->getType(),
168                                 LI->getAlignment(),
169                                 LI->getPointerAddressSpace());
170  }
171  case Instruction::ZExt:
172  case Instruction::SExt:
173  case Instruction::FPToUI:
174  case Instruction::FPToSI:
175  case Instruction::FPExt:
176  case Instruction::PtrToInt:
177  case Instruction::IntToPtr:
178  case Instruction::SIToFP:
179  case Instruction::UIToFP:
180  case Instruction::Trunc:
181  case Instruction::FPTrunc:
182  case Instruction::BitCast: {
183    Type *SrcTy = I->getOperand(0)->getType();
184    return TTI->getCastInstrCost(I->getOpcode(), I->getType(), SrcTy);
185  }
186  case Instruction::ExtractElement: {
187    const ExtractElementInst * EEI = cast<ExtractElementInst>(I);
188    ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
189    unsigned Idx = -1;
190    if (CI)
191      Idx = CI->getZExtValue();
192    return TTI->getVectorInstrCost(I->getOpcode(),
193                                   EEI->getOperand(0)->getType(), Idx);
194  }
195  case Instruction::InsertElement: {
196      const InsertElementInst * IE = cast<InsertElementInst>(I);
197      ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
198      unsigned Idx = -1;
199      if (CI)
200        Idx = CI->getZExtValue();
201      return TTI->getVectorInstrCost(I->getOpcode(),
202                                     IE->getType(), Idx);
203    }
204  case Instruction::ShuffleVector: {
205    const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
206    Type *VecTypOp0 = Shuffle->getOperand(0)->getType();
207    unsigned NumVecElems = VecTypOp0->getVectorNumElements();
208    SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
209
210    if (NumVecElems == Mask.size() && isReverseVectorMask(Mask))
211      return TTI->getShuffleCost(TargetTransformInfo::SK_Reverse, VecTypOp0, 0,
212                                 0);
213    return -1;
214  }
215  case Instruction::Call:
216    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
217      SmallVector<Type*, 4> Tys;
218      for (unsigned J = 0, JE = II->getNumArgOperands(); J != JE; ++J)
219        Tys.push_back(II->getArgOperand(J)->getType());
220
221      return TTI->getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(),
222                                        Tys);
223    }
224    return -1;
225  default:
226    // We don't have any information on this instruction.
227    return -1;
228  }
229}
230
231void CostModelAnalysis::print(raw_ostream &OS, const Module*) const {
232  if (!F)
233    return;
234
235  for (Function::iterator B = F->begin(), BE = F->end(); B != BE; ++B) {
236    for (BasicBlock::iterator it = B->begin(), e = B->end(); it != e; ++it) {
237      Instruction *Inst = it;
238      unsigned Cost = getInstructionCost(Inst);
239      if (Cost != (unsigned)-1)
240        OS << "Cost Model: Found an estimated cost of " << Cost;
241      else
242        OS << "Cost Model: Unknown cost";
243
244      OS << " for instruction: "<< *Inst << "\n";
245    }
246  }
247}
248