1239310Sdim//===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===//
2239310Sdim//
3239310Sdim//                     The LLVM Compiler Infrastructure
4239310Sdim//
5239310Sdim// This file is distributed under the University of Illinois Open Source
6239310Sdim// License. See LICENSE.TXT for details.
7239310Sdim//
8239310Sdim//===----------------------------------------------------------------------===//
// Lower aggregate copies and the memset, memcpy and memmove intrinsics into
// loops when the size is large or is not a compile-time constant.
11239310Sdim//
12239310Sdim//===----------------------------------------------------------------------===//
13239310Sdim
14239310Sdim#include "NVPTXLowerAggrCopies.h"
15249423Sdim#include "llvm/IR/Constants.h"
16249423Sdim#include "llvm/IR/DataLayout.h"
17249423Sdim#include "llvm/IR/Function.h"
18249423Sdim#include "llvm/IR/IRBuilder.h"
19249423Sdim#include "llvm/IR/Instructions.h"
20249423Sdim#include "llvm/IR/IntrinsicInst.h"
21249423Sdim#include "llvm/IR/Intrinsics.h"
22249423Sdim#include "llvm/IR/LLVMContext.h"
23249423Sdim#include "llvm/IR/Module.h"
24239310Sdim#include "llvm/Support/InstIterator.h"
25239310Sdim
26239310Sdimusing namespace llvm;
27239310Sdim
28249423Sdimnamespace llvm { FunctionPass *createLowerAggrCopies(); }
29239310Sdim
30239310Sdimchar NVPTXLowerAggrCopies::ID = 0;
31239310Sdim
32239310Sdim// Lower MemTransferInst or load-store pair to loop
33249423Sdimstatic void convertTransferToLoop(
34249423Sdim    Instruction *splitAt, Value *srcAddr, Value *dstAddr, Value *len,
35249423Sdim    //unsigned numLoads,
36249423Sdim    bool srcVolatile, bool dstVolatile, LLVMContext &Context, Function &F) {
37239310Sdim  Type *indType = len->getType();
38239310Sdim
39239310Sdim  BasicBlock *origBB = splitAt->getParent();
40239310Sdim  BasicBlock *newBB = splitAt->getParent()->splitBasicBlock(splitAt, "split");
41239310Sdim  BasicBlock *loopBB = BasicBlock::Create(Context, "loadstoreloop", &F, newBB);
42239310Sdim
43239310Sdim  origBB->getTerminator()->setSuccessor(0, loopBB);
44239310Sdim  IRBuilder<> builder(origBB, origBB->getTerminator());
45239310Sdim
46239310Sdim  // srcAddr and dstAddr are expected to be pointer types,
47239310Sdim  // so no check is made here.
48249423Sdim  unsigned srcAS = dyn_cast<PointerType>(srcAddr->getType())->getAddressSpace();
49249423Sdim  unsigned dstAS = dyn_cast<PointerType>(dstAddr->getType())->getAddressSpace();
50239310Sdim
51239310Sdim  // Cast pointers to (char *)
52239310Sdim  srcAddr = builder.CreateBitCast(srcAddr, Type::getInt8PtrTy(Context, srcAS));
53239310Sdim  dstAddr = builder.CreateBitCast(dstAddr, Type::getInt8PtrTy(Context, dstAS));
54239310Sdim
55239310Sdim  IRBuilder<> loop(loopBB);
56239310Sdim  // The loop index (ind) is a phi node.
57239310Sdim  PHINode *ind = loop.CreatePHI(indType, 0);
58239310Sdim  // Incoming value for ind is 0
59239310Sdim  ind->addIncoming(ConstantInt::get(indType, 0), origBB);
60239310Sdim
61239310Sdim  // load from srcAddr+ind
62239310Sdim  Value *val = loop.CreateLoad(loop.CreateGEP(srcAddr, ind), srcVolatile);
63239310Sdim  // store at dstAddr+ind
64239310Sdim  loop.CreateStore(val, loop.CreateGEP(dstAddr, ind), dstVolatile);
65239310Sdim
66239310Sdim  // The value for ind coming from backedge is (ind + 1)
67239310Sdim  Value *newind = loop.CreateAdd(ind, ConstantInt::get(indType, 1));
68239310Sdim  ind->addIncoming(newind, loopBB);
69239310Sdim
70239310Sdim  loop.CreateCondBr(loop.CreateICmpULT(newind, len), loopBB, newBB);
71239310Sdim}
72239310Sdim
73239310Sdim// Lower MemSetInst to loop
74239310Sdimstatic void convertMemSetToLoop(Instruction *splitAt, Value *dstAddr,
75239310Sdim                                Value *len, Value *val, LLVMContext &Context,
76239310Sdim                                Function &F) {
77239310Sdim  BasicBlock *origBB = splitAt->getParent();
78239310Sdim  BasicBlock *newBB = splitAt->getParent()->splitBasicBlock(splitAt, "split");
79239310Sdim  BasicBlock *loopBB = BasicBlock::Create(Context, "loadstoreloop", &F, newBB);
80239310Sdim
81239310Sdim  origBB->getTerminator()->setSuccessor(0, loopBB);
82239310Sdim  IRBuilder<> builder(origBB, origBB->getTerminator());
83239310Sdim
84249423Sdim  unsigned dstAS = dyn_cast<PointerType>(dstAddr->getType())->getAddressSpace();
85239310Sdim
86239310Sdim  // Cast pointer to the type of value getting stored
87249423Sdim  dstAddr =
88249423Sdim      builder.CreateBitCast(dstAddr, PointerType::get(val->getType(), dstAS));
89239310Sdim
90239310Sdim  IRBuilder<> loop(loopBB);
91239310Sdim  PHINode *ind = loop.CreatePHI(len->getType(), 0);
92239310Sdim  ind->addIncoming(ConstantInt::get(len->getType(), 0), origBB);
93239310Sdim
94239310Sdim  loop.CreateStore(val, loop.CreateGEP(dstAddr, ind), false);
95239310Sdim
96239310Sdim  Value *newind = loop.CreateAdd(ind, ConstantInt::get(len->getType(), 1));
97239310Sdim  ind->addIncoming(newind, loopBB);
98239310Sdim
99239310Sdim  loop.CreateCondBr(loop.CreateICmpULT(newind, len), loopBB, newBB);
100239310Sdim}
101239310Sdim
102239310Sdimbool NVPTXLowerAggrCopies::runOnFunction(Function &F) {
103239310Sdim  SmallVector<LoadInst *, 4> aggrLoads;
104239310Sdim  SmallVector<MemTransferInst *, 4> aggrMemcpys;
105239310Sdim  SmallVector<MemSetInst *, 4> aggrMemsets;
106239310Sdim
107243830Sdim  DataLayout *TD = &getAnalysis<DataLayout>();
108239310Sdim  LLVMContext &Context = F.getParent()->getContext();
109239310Sdim
110239310Sdim  //
111239310Sdim  // Collect all the aggrLoads, aggrMemcpys and addrMemsets.
112239310Sdim  //
113239310Sdim  //const BasicBlock *firstBB = &F.front();  // first BB in F
114239310Sdim  for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) {
115239310Sdim    //BasicBlock *bb = BI;
116239310Sdim    for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); II != IE;
117249423Sdim         ++II) {
118249423Sdim      if (LoadInst *load = dyn_cast<LoadInst>(II)) {
119239310Sdim
120249423Sdim        if (load->hasOneUse() == false)
121249423Sdim          continue;
122239310Sdim
123249423Sdim        if (TD->getTypeStoreSize(load->getType()) < MaxAggrCopySize)
124249423Sdim          continue;
125239310Sdim
126239310Sdim        User *use = *(load->use_begin());
127249423Sdim        if (StoreInst *store = dyn_cast<StoreInst>(use)) {
128239310Sdim          if (store->getOperand(0) != load) //getValueOperand
129249423Sdim            continue;
130239310Sdim          aggrLoads.push_back(load);
131239310Sdim        }
132249423Sdim      } else if (MemTransferInst *intr = dyn_cast<MemTransferInst>(II)) {
133239310Sdim        Value *len = intr->getLength();
134239310Sdim        // If the number of elements being copied is greater
135239310Sdim        // than MaxAggrCopySize, lower it to a loop
136249423Sdim        if (ConstantInt *len_int = dyn_cast<ConstantInt>(len)) {
137239310Sdim          if (len_int->getZExtValue() >= MaxAggrCopySize) {
138239310Sdim            aggrMemcpys.push_back(intr);
139239310Sdim          }
140239310Sdim        } else {
141239310Sdim          // turn variable length memcpy/memmov into loop
142239310Sdim          aggrMemcpys.push_back(intr);
143239310Sdim        }
144249423Sdim      } else if (MemSetInst *memsetintr = dyn_cast<MemSetInst>(II)) {
145239310Sdim        Value *len = memsetintr->getLength();
146249423Sdim        if (ConstantInt *len_int = dyn_cast<ConstantInt>(len)) {
147239310Sdim          if (len_int->getZExtValue() >= MaxAggrCopySize) {
148239310Sdim            aggrMemsets.push_back(memsetintr);
149239310Sdim          }
150239310Sdim        } else {
151239310Sdim          // turn variable length memset into loop
152239310Sdim          aggrMemsets.push_back(memsetintr);
153239310Sdim        }
154239310Sdim      }
155239310Sdim    }
156239310Sdim  }
157249423Sdim  if ((aggrLoads.size() == 0) && (aggrMemcpys.size() == 0) &&
158249423Sdim      (aggrMemsets.size() == 0))
159249423Sdim    return false;
160239310Sdim
161239310Sdim  //
162239310Sdim  // Do the transformation of an aggr load/copy/set to a loop
163239310Sdim  //
164239310Sdim  for (unsigned i = 0, e = aggrLoads.size(); i != e; ++i) {
165239310Sdim    LoadInst *load = aggrLoads[i];
166239310Sdim    StoreInst *store = dyn_cast<StoreInst>(*load->use_begin());
167239310Sdim    Value *srcAddr = load->getOperand(0);
168239310Sdim    Value *dstAddr = store->getOperand(1);
169239310Sdim    unsigned numLoads = TD->getTypeStoreSize(load->getType());
170239310Sdim    Value *len = ConstantInt::get(Type::getInt32Ty(Context), numLoads);
171239310Sdim
172239310Sdim    convertTransferToLoop(store, srcAddr, dstAddr, len, load->isVolatile(),
173239310Sdim                          store->isVolatile(), Context, F);
174239310Sdim
175239310Sdim    store->eraseFromParent();
176239310Sdim    load->eraseFromParent();
177239310Sdim  }
178239310Sdim
179239310Sdim  for (unsigned i = 0, e = aggrMemcpys.size(); i != e; ++i) {
180239310Sdim    MemTransferInst *cpy = aggrMemcpys[i];
181239310Sdim    Value *len = cpy->getLength();
182239310Sdim    // llvm 2.7 version of memcpy does not have volatile
183239310Sdim    // operand yet. So always making it non-volatile
184239310Sdim    // optimistically, so that we don't see unnecessary
185239310Sdim    // st.volatile in ptx
186239310Sdim    convertTransferToLoop(cpy, cpy->getSource(), cpy->getDest(), len, false,
187239310Sdim                          false, Context, F);
188239310Sdim    cpy->eraseFromParent();
189239310Sdim  }
190239310Sdim
191239310Sdim  for (unsigned i = 0, e = aggrMemsets.size(); i != e; ++i) {
192239310Sdim    MemSetInst *memsetinst = aggrMemsets[i];
193239310Sdim    Value *len = memsetinst->getLength();
194239310Sdim    Value *val = memsetinst->getValue();
195239310Sdim    convertMemSetToLoop(memsetinst, memsetinst->getDest(), len, val, Context,
196239310Sdim                        F);
197239310Sdim    memsetinst->eraseFromParent();
198239310Sdim  }
199239310Sdim
200239310Sdim  return true;
201239310Sdim}
202239310Sdim
203239310SdimFunctionPass *llvm::createLowerAggrCopies() {
204239310Sdim  return new NVPTXLowerAggrCopies();
205239310Sdim}
206