// CGGPUBuiltin.cpp revision 353358
197403Sobrien//===------ CGGPUBuiltin.cpp - Codegen for GPU builtins -------------------===// 297403Sobrien// 3132720Skan// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 497403Sobrien// See https://llvm.org/LICENSE.txt for license information. 597403Sobrien// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 697403Sobrien// 797403Sobrien//===----------------------------------------------------------------------===// 897403Sobrien// 997403Sobrien// Generates code for built-in GPU calls which are not runtime-specific. 1097403Sobrien// (Runtime-specific codegen lives in programming model specific files.) 1197403Sobrien// 1297403Sobrien//===----------------------------------------------------------------------===// 1397403Sobrien 1497403Sobrien#include "CodeGenFunction.h" 1597403Sobrien#include "clang/Basic/Builtins.h" 1697403Sobrien#include "llvm/IR/DataLayout.h" 1797403Sobrien#include "llvm/IR/Instruction.h" 1897403Sobrien#include "llvm/Support/MathExtras.h" 1997403Sobrien 2097403Sobrienusing namespace clang; 2197403Sobrienusing namespace CodeGen; 2297403Sobrien 2397403Sobrienstatic llvm::Function *GetVprintfDeclaration(llvm::Module &M) { 2497403Sobrien llvm::Type *ArgTypes[] = {llvm::Type::getInt8PtrTy(M.getContext()), 2597403Sobrien llvm::Type::getInt8PtrTy(M.getContext())}; 2697403Sobrien llvm::FunctionType *VprintfFuncType = llvm::FunctionType::get( 2797403Sobrien llvm::Type::getInt32Ty(M.getContext()), ArgTypes, false); 2897403Sobrien 2997403Sobrien if (auto* F = M.getFunction("vprintf")) { 3097403Sobrien // Our CUDA system header declares vprintf with the right signature, so 3197403Sobrien // nobody else should have been able to declare vprintf with a bogus 3297403Sobrien // signature. 3397403Sobrien assert(F->getFunctionType() == VprintfFuncType); 3497403Sobrien return F; 3597403Sobrien } 3697403Sobrien 3797403Sobrien // vprintf doesn't already exist; create a declaration and insert it into the 3897403Sobrien // module. 
3997403Sobrien return llvm::Function::Create( 40132720Skan VprintfFuncType, llvm::GlobalVariable::ExternalLinkage, "vprintf", &M); 41132720Skan} 4297403Sobrien 4397403Sobrien// Transforms a call to printf into a call to the NVPTX vprintf syscall (which 4497403Sobrien// isn't particularly special; it's invoked just like a regular function). 4597403Sobrien// vprintf takes two args: A format string, and a pointer to a buffer containing 4697403Sobrien// the varargs. 4797403Sobrien// 4897403Sobrien// For example, the call 4997403Sobrien// 5097403Sobrien// printf("format string", arg1, arg2, arg3); 5197403Sobrien// 52117397Skan// is converted into something resembling 53117397Skan// 54117397Skan// struct Tmp { 55117397Skan// Arg1 a1; 56117397Skan// Arg2 a2; 5797403Sobrien// Arg3 a3; 5897403Sobrien// }; 59132720Skan// char* buf = alloca(sizeof(Tmp)); 6097403Sobrien// *(Tmp*)buf = {a1, a2, a3}; 6197403Sobrien// vprintf("format string", buf); 62117397Skan// 63117397Skan// buf is aligned to the max of {alignof(Arg1), ...}. Furthermore, each of the 64117397Skan// args is itself aligned to its preferred alignment. 65117397Skan// 66117397Skan// Note that by the time this function runs, E's args have already undergone the 67117397Skan// standard C vararg promotion (short -> int, float -> double, etc.). 68117397SkanRValue 69117397SkanCodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E, 70117397Skan ReturnValueSlot ReturnValue) { 71117397Skan assert(getTarget().getTriple().isNVPTX()); 72117397Skan assert(E->getBuiltinCallee() == Builtin::BIprintf); 73117397Skan assert(E->getNumArgs() >= 1); // printf always has at least one arg. 
74117397Skan 75117397Skan const llvm::DataLayout &DL = CGM.getDataLayout(); 76117397Skan llvm::LLVMContext &Ctx = CGM.getLLVMContext(); 77117397Skan 78117397Skan CallArgList Args; 79117397Skan EmitCallArgs(Args, 80117397Skan E->getDirectCallee()->getType()->getAs<FunctionProtoType>(), 81117397Skan E->arguments(), E->getDirectCallee(), 82117397Skan /* ParamsToSkip = */ 0); 83117397Skan 84117397Skan // We don't know how to emit non-scalar varargs. 85117397Skan if (std::any_of(Args.begin() + 1, Args.end(), [&](const CallArg &A) { 86117397Skan return !A.getRValue(*this).isScalar(); 87117397Skan })) { 88117397Skan CGM.ErrorUnsupported(E, "non-scalar arg to printf"); 89117397Skan return RValue::get(llvm::ConstantInt::get(IntTy, 0)); 90117397Skan } 91117397Skan 92117397Skan // Construct and fill the args buffer that we'll pass to vprintf. 93117397Skan llvm::Value *BufferPtr; 94117397Skan if (Args.size() <= 1) { 95117397Skan // If there are no args, pass a null pointer to vprintf. 96117397Skan BufferPtr = llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(Ctx)); 97117397Skan } else { 98117397Skan llvm::SmallVector<llvm::Type *, 8> ArgTypes; 99117397Skan for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I) 100117397Skan ArgTypes.push_back(Args[I].getRValue(*this).getScalarVal()->getType()); 101117397Skan 102117397Skan // Using llvm::StructType is correct only because printf doesn't accept 103117397Skan // aggregates. If we had to handle aggregates here, we'd have to manually 104117397Skan // compute the offsets within the alloca -- we wouldn't be able to assume 105117397Skan // that the alignment of the llvm type was the same as the alignment of the 106117397Skan // clang type. 
107117397Skan llvm::Type *AllocaTy = llvm::StructType::create(ArgTypes, "printf_args"); 108117397Skan llvm::Value *Alloca = CreateTempAlloca(AllocaTy); 109117397Skan 110117397Skan for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I) { 111117397Skan llvm::Value *P = Builder.CreateStructGEP(AllocaTy, Alloca, I - 1); 112117397Skan llvm::Value *Arg = Args[I].getRValue(*this).getScalarVal(); 113117397Skan Builder.CreateAlignedStore(Arg, P, DL.getPrefTypeAlignment(Arg->getType())); 114117397Skan } 115117397Skan BufferPtr = Builder.CreatePointerCast(Alloca, llvm::Type::getInt8PtrTy(Ctx)); 116117397Skan } 117117397Skan 118117397Skan // Invoke vprintf and return. 119117397Skan llvm::Function* VprintfFunc = GetVprintfDeclaration(CGM.getModule()); 120117397Skan return RValue::get(Builder.CreateCall( 121117397Skan VprintfFunc, {Args[0].getRValue(*this).getScalarVal(), BufferPtr})); 12297403Sobrien} 12397403Sobrien