1249259Sdim//===- NVVMReflect.cpp - NVVM Emulate conditional compilation -------------===// 2249259Sdim// 3249259Sdim// The LLVM Compiler Infrastructure 4249259Sdim// 5249259Sdim// This file is distributed under the University of Illinois Open Source 6249259Sdim// License. See LICENSE.TXT for details. 7249259Sdim// 8249259Sdim//===----------------------------------------------------------------------===// 9249259Sdim// 10249259Sdim// This pass replaces occurences of __nvvm_reflect("string") with an 11249259Sdim// integer based on -nvvm-reflect-list string=<int> option given to this pass. 12249259Sdim// If an undefined string value is seen in a call to __nvvm_reflect("string"), 13249259Sdim// a default value of 0 will be used. 14249259Sdim// 15249259Sdim//===----------------------------------------------------------------------===// 16249259Sdim 17251662Sdim#include "NVPTX.h" 18249259Sdim#include "llvm/ADT/DenseMap.h" 19249259Sdim#include "llvm/ADT/SmallVector.h" 20249259Sdim#include "llvm/ADT/StringMap.h" 21249259Sdim#include "llvm/Pass.h" 22249259Sdim#include "llvm/IR/Function.h" 23249259Sdim#include "llvm/IR/Module.h" 24249259Sdim#include "llvm/IR/Type.h" 25249259Sdim#include "llvm/IR/DerivedTypes.h" 26249259Sdim#include "llvm/IR/Instructions.h" 27249259Sdim#include "llvm/IR/Constants.h" 28249259Sdim#include "llvm/Support/CommandLine.h" 29249259Sdim#include "llvm/Support/Debug.h" 30249259Sdim#include "llvm/Support/raw_os_ostream.h" 31249259Sdim#include "llvm/Transforms/Scalar.h" 32249259Sdim#include <map> 33249259Sdim#include <sstream> 34249259Sdim#include <string> 35249259Sdim#include <vector> 36249259Sdim 37249259Sdim#define NVVM_REFLECT_FUNCTION "__nvvm_reflect" 38249259Sdim 39249259Sdimusing namespace llvm; 40249259Sdim 41249259Sdimnamespace llvm { void initializeNVVMReflectPass(PassRegistry &); } 42249259Sdim 43249259Sdimnamespace { 44251662Sdimclass NVVMReflect : public ModulePass { 45249259Sdimprivate: 46249259Sdim StringMap<int> VarMap; 47249259Sdim typedef DenseMap<std::string, int>::iterator VarMapIter; 48249259Sdim Function *ReflectFunction; 49249259Sdim 50249259Sdimpublic: 51249259Sdim static char ID; 52251662Sdim NVVMReflect() : ModulePass(ID), ReflectFunction(0) { 53251662Sdim initializeNVVMReflectPass(*PassRegistry::getPassRegistry()); 54249259Sdim VarMap.clear(); 55249259Sdim } 56249259Sdim 57251662Sdim NVVMReflect(const StringMap<int> &Mapping) 58251662Sdim : ModulePass(ID), ReflectFunction(0) { 59251662Sdim initializeNVVMReflectPass(*PassRegistry::getPassRegistry()); 60251662Sdim for (StringMap<int>::const_iterator I = Mapping.begin(), E = Mapping.end(); 61251662Sdim I != E; ++I) { 62251662Sdim VarMap[(*I).getKey()] = (*I).getValue(); 63251662Sdim } 64251662Sdim } 65251662Sdim 66249259Sdim void getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); } 67249259Sdim virtual bool runOnModule(Module &); 68249259Sdim 69249259Sdim void setVarMap(); 70249259Sdim}; 71249259Sdim} 72249259Sdim 73251662SdimModulePass *llvm::createNVVMReflectPass() { 74251662Sdim return new NVVMReflect(); 75251662Sdim} 76251662Sdim 77251662SdimModulePass *llvm::createNVVMReflectPass(const StringMap<int>& Mapping) { 78251662Sdim return new NVVMReflect(Mapping); 79251662Sdim} 80251662Sdim 81249259Sdimstatic cl::opt<bool> 82263508SdimNVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden, 83249259Sdim cl::desc("NVVM reflection, enabled by default")); 84249259Sdim 85249259Sdimchar NVVMReflect::ID = 0; 86249259SdimINITIALIZE_PASS(NVVMReflect, "nvvm-reflect", 87249259Sdim "Replace occurences of __nvvm_reflect() calls with 0/1", false, 88249259Sdim false) 89249259Sdim 90249259Sdimstatic cl::list<std::string> 91263508SdimReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden, 92249259Sdim cl::desc("A list of string=num assignments"), 93249259Sdim cl::ValueRequired); 94249259Sdim 95249259Sdim/// The command line can look as follows : 96249259Sdim/// -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2 97249259Sdim/// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the 98249259Sdim/// ReflectList vector. First, each of ReflectList[i] is 'split' 99249259Sdim/// using "," as the delimiter. Then each of this part is split 100249259Sdim/// using "=" as the delimiter. 101249259Sdimvoid NVVMReflect::setVarMap() { 102249259Sdim for (unsigned i = 0, e = ReflectList.size(); i != e; ++i) { 103249259Sdim DEBUG(dbgs() << "Option : " << ReflectList[i] << "\n"); 104249259Sdim SmallVector<StringRef, 4> NameValList; 105249259Sdim StringRef(ReflectList[i]).split(NameValList, ","); 106249259Sdim for (unsigned j = 0, ej = NameValList.size(); j != ej; ++j) { 107249259Sdim SmallVector<StringRef, 2> NameValPair; 108249259Sdim NameValList[j].split(NameValPair, "="); 109249259Sdim assert(NameValPair.size() == 2 && "name=val expected"); 110249259Sdim std::stringstream ValStream(NameValPair[1]); 111249259Sdim int Val; 112249259Sdim ValStream >> Val; 113249259Sdim assert((!(ValStream.fail())) && "integer value expected"); 114249259Sdim VarMap[NameValPair[0]] = Val; 115249259Sdim } 116249259Sdim } 117249259Sdim} 118249259Sdim 119249259Sdimbool NVVMReflect::runOnModule(Module &M) { 120249259Sdim if (!NVVMReflectEnabled) 121249259Sdim return false; 122249259Sdim 123249259Sdim setVarMap(); 124249259Sdim 125249259Sdim ReflectFunction = M.getFunction(NVVM_REFLECT_FUNCTION); 126249259Sdim 127249259Sdim // If reflect function is not used, then there will be 128249259Sdim // no entry in the module. 129249259Sdim if (ReflectFunction == 0) 130249259Sdim return false; 131249259Sdim 132249259Sdim // Validate _reflect function 133249259Sdim assert(ReflectFunction->isDeclaration() && 134249259Sdim "_reflect function should not have a body"); 135249259Sdim assert(ReflectFunction->getReturnType()->isIntegerTy() && 136249259Sdim "_reflect's return type should be integer"); 137249259Sdim 138249259Sdim std::vector<Instruction *> ToRemove; 139249259Sdim 140249259Sdim // Go through the uses of ReflectFunction in this Function. 141249259Sdim // Each of them should a CallInst with a ConstantArray argument. 142249259Sdim // First validate that. If the c-string corresponding to the 143249259Sdim // ConstantArray can be found successfully, see if it can be 144249259Sdim // found in VarMap. If so, replace the uses of CallInst with the 145249259Sdim // value found in VarMap. If not, replace the use with value 0. 146249259Sdim for (Value::use_iterator I = ReflectFunction->use_begin(), 147249259Sdim E = ReflectFunction->use_end(); 148249259Sdim I != E; ++I) { 149249259Sdim assert(isa<CallInst>(*I) && "Only a call instruction can use _reflect"); 150249259Sdim CallInst *Reflect = cast<CallInst>(*I); 151249259Sdim 152249259Sdim assert((Reflect->getNumOperands() == 2) && 153249259Sdim "Only one operand expect for _reflect function"); 154249259Sdim // In cuda, we will have an extra constant-to-generic conversion of 155249259Sdim // the string. 156249259Sdim const Value *conv = Reflect->getArgOperand(0); 157249259Sdim assert(isa<CallInst>(conv) && "Expected a const-to-gen conversion"); 158249259Sdim const CallInst *ConvCall = cast<CallInst>(conv); 159249259Sdim const Value *str = ConvCall->getArgOperand(0); 160249259Sdim assert(isa<ConstantExpr>(str) && 161249259Sdim "Format of _reflect function not recognized"); 162249259Sdim const ConstantExpr *GEP = cast<ConstantExpr>(str); 163249259Sdim 164249259Sdim const Value *Sym = GEP->getOperand(0); 165249259Sdim assert(isa<Constant>(Sym) && "Format of _reflect function not recognized"); 166249259Sdim 167249259Sdim const Constant *SymStr = cast<Constant>(Sym); 168249259Sdim 169249259Sdim assert(isa<ConstantDataSequential>(SymStr->getOperand(0)) && 170249259Sdim "Format of _reflect function not recognized"); 171249259Sdim 172249259Sdim assert(cast<ConstantDataSequential>(SymStr->getOperand(0))->isCString() && 173249259Sdim "Format of _reflect function not recognized"); 174249259Sdim 175249259Sdim std::string ReflectArg = 176249259Sdim cast<ConstantDataSequential>(SymStr->getOperand(0))->getAsString(); 177249259Sdim 178249259Sdim ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1); 179249259Sdim DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n"); 180249259Sdim 181249259Sdim int ReflectVal = 0; // The default value is 0 182249259Sdim if (VarMap.find(ReflectArg) != VarMap.end()) { 183249259Sdim ReflectVal = VarMap[ReflectArg]; 184249259Sdim } 185249259Sdim Reflect->replaceAllUsesWith( 186249259Sdim ConstantInt::get(Reflect->getType(), ReflectVal)); 187249259Sdim ToRemove.push_back(Reflect); 188249259Sdim } 189249259Sdim if (ToRemove.size() == 0) 190249259Sdim return false; 191249259Sdim 192249259Sdim for (unsigned i = 0, e = ToRemove.size(); i != e; ++i) 193249259Sdim ToRemove[i]->eraseFromParent(); 194249259Sdim return true; 195249259Sdim} 196