CGGPUBuiltin.cpp revision 353358
//===------ CGGPUBuiltin.cpp - Codegen for GPU builtins -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Generates code for built-in GPU calls which are not runtime-specific.
// (Runtime-specific codegen lives in programming model specific files.)
//
//===----------------------------------------------------------------------===//
1397403Sobrien
1497403Sobrien#include "CodeGenFunction.h"
1597403Sobrien#include "clang/Basic/Builtins.h"
1697403Sobrien#include "llvm/IR/DataLayout.h"
1797403Sobrien#include "llvm/IR/Instruction.h"
1897403Sobrien#include "llvm/Support/MathExtras.h"
1997403Sobrien
2097403Sobrienusing namespace clang;
2197403Sobrienusing namespace CodeGen;
2297403Sobrien
2397403Sobrienstatic llvm::Function *GetVprintfDeclaration(llvm::Module &M) {
2497403Sobrien  llvm::Type *ArgTypes[] = {llvm::Type::getInt8PtrTy(M.getContext()),
2597403Sobrien                            llvm::Type::getInt8PtrTy(M.getContext())};
2697403Sobrien  llvm::FunctionType *VprintfFuncType = llvm::FunctionType::get(
2797403Sobrien      llvm::Type::getInt32Ty(M.getContext()), ArgTypes, false);
2897403Sobrien
2997403Sobrien  if (auto* F = M.getFunction("vprintf")) {
3097403Sobrien    // Our CUDA system header declares vprintf with the right signature, so
3197403Sobrien    // nobody else should have been able to declare vprintf with a bogus
3297403Sobrien    // signature.
3397403Sobrien    assert(F->getFunctionType() == VprintfFuncType);
3497403Sobrien    return F;
3597403Sobrien  }
3697403Sobrien
3797403Sobrien  // vprintf doesn't already exist; create a declaration and insert it into the
3897403Sobrien  // module.
3997403Sobrien  return llvm::Function::Create(
40132720Skan      VprintfFuncType, llvm::GlobalVariable::ExternalLinkage, "vprintf", &M);
41132720Skan}
4297403Sobrien
// Transforms a call to printf into a call to the NVPTX vprintf syscall (which
// isn't particularly special; it's invoked just like a regular function).
// vprintf takes two args: a format string, and a pointer to a buffer containing
// the varargs.
//
// For example, the call
//
//   printf("format string", arg1, arg2, arg3);
//
// is converted into something resembling
//
//   struct Tmp {
//     Arg1 a1;
//     Arg2 a2;
//     Arg3 a3;
//   };
//   char* buf = alloca(sizeof(Tmp));
//   *(Tmp*)buf = {a1, a2, a3};
//   vprintf("format string", buf);
//
// buf is aligned to the max of {alignof(Arg1), ...}.  Furthermore, each of the
// args is itself aligned to its preferred alignment.
//
// Note that by the time this function runs, E's args have already undergone the
// standard C vararg promotion (short -> int, float -> double, etc.).
68117397SkanRValue
69117397SkanCodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E,
70117397Skan                                               ReturnValueSlot ReturnValue) {
71117397Skan  assert(getTarget().getTriple().isNVPTX());
72117397Skan  assert(E->getBuiltinCallee() == Builtin::BIprintf);
73117397Skan  assert(E->getNumArgs() >= 1); // printf always has at least one arg.
74117397Skan
75117397Skan  const llvm::DataLayout &DL = CGM.getDataLayout();
76117397Skan  llvm::LLVMContext &Ctx = CGM.getLLVMContext();
77117397Skan
78117397Skan  CallArgList Args;
79117397Skan  EmitCallArgs(Args,
80117397Skan               E->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
81117397Skan               E->arguments(), E->getDirectCallee(),
82117397Skan               /* ParamsToSkip = */ 0);
83117397Skan
84117397Skan  // We don't know how to emit non-scalar varargs.
85117397Skan  if (std::any_of(Args.begin() + 1, Args.end(), [&](const CallArg &A) {
86117397Skan        return !A.getRValue(*this).isScalar();
87117397Skan      })) {
88117397Skan    CGM.ErrorUnsupported(E, "non-scalar arg to printf");
89117397Skan    return RValue::get(llvm::ConstantInt::get(IntTy, 0));
90117397Skan  }
91117397Skan
92117397Skan  // Construct and fill the args buffer that we'll pass to vprintf.
93117397Skan  llvm::Value *BufferPtr;
94117397Skan  if (Args.size() <= 1) {
95117397Skan    // If there are no args, pass a null pointer to vprintf.
96117397Skan    BufferPtr = llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(Ctx));
97117397Skan  } else {
98117397Skan    llvm::SmallVector<llvm::Type *, 8> ArgTypes;
99117397Skan    for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I)
100117397Skan      ArgTypes.push_back(Args[I].getRValue(*this).getScalarVal()->getType());
101117397Skan
102117397Skan    // Using llvm::StructType is correct only because printf doesn't accept
103117397Skan    // aggregates.  If we had to handle aggregates here, we'd have to manually
104117397Skan    // compute the offsets within the alloca -- we wouldn't be able to assume
105117397Skan    // that the alignment of the llvm type was the same as the alignment of the
106117397Skan    // clang type.
107117397Skan    llvm::Type *AllocaTy = llvm::StructType::create(ArgTypes, "printf_args");
108117397Skan    llvm::Value *Alloca = CreateTempAlloca(AllocaTy);
109117397Skan
110117397Skan    for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I) {
111117397Skan      llvm::Value *P = Builder.CreateStructGEP(AllocaTy, Alloca, I - 1);
112117397Skan      llvm::Value *Arg = Args[I].getRValue(*this).getScalarVal();
113117397Skan      Builder.CreateAlignedStore(Arg, P, DL.getPrefTypeAlignment(Arg->getType()));
114117397Skan    }
115117397Skan    BufferPtr = Builder.CreatePointerCast(Alloca, llvm::Type::getInt8PtrTy(Ctx));
116117397Skan  }
117117397Skan
118117397Skan  // Invoke vprintf and return.
119117397Skan  llvm::Function* VprintfFunc = GetVprintfDeclaration(CGM.getModule());
120117397Skan  return RValue::get(Builder.CreateCall(
121117397Skan      VprintfFunc, {Args[0].getRValue(*this).getScalarVal(), BufferPtr}));
12297403Sobrien}
12397403Sobrien