1//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpenMP specific optimizations:
10//
11// - Deduplication of runtime calls, e.g., omp_get_thread_num.
12// - Replacing globalized device memory with stack memory.
13// - Replacing globalized device memory with shared memory.
14// - Parallel region merging.
15// - Transforming generic-mode device kernels to SPMD mode.
16// - Specializing the state machine for generic-mode device kernels.
17//
18//===----------------------------------------------------------------------===//
19
20#include "llvm/Transforms/IPO/OpenMPOpt.h"
21
22#include "llvm/ADT/EnumeratedArray.h"
23#include "llvm/ADT/PostOrderIterator.h"
24#include "llvm/ADT/SetVector.h"
25#include "llvm/ADT/SmallPtrSet.h"
26#include "llvm/ADT/SmallVector.h"
27#include "llvm/ADT/Statistic.h"
28#include "llvm/ADT/StringExtras.h"
29#include "llvm/ADT/StringRef.h"
30#include "llvm/Analysis/CallGraph.h"
31#include "llvm/Analysis/CallGraphSCCPass.h"
32#include "llvm/Analysis/MemoryLocation.h"
33#include "llvm/Analysis/OptimizationRemarkEmitter.h"
34#include "llvm/Analysis/ValueTracking.h"
35#include "llvm/Frontend/OpenMP/OMPConstants.h"
36#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
37#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
38#include "llvm/IR/Assumptions.h"
39#include "llvm/IR/BasicBlock.h"
40#include "llvm/IR/Constants.h"
41#include "llvm/IR/DiagnosticInfo.h"
42#include "llvm/IR/Dominators.h"
43#include "llvm/IR/Function.h"
44#include "llvm/IR/GlobalValue.h"
45#include "llvm/IR/GlobalVariable.h"
46#include "llvm/IR/InstrTypes.h"
47#include "llvm/IR/Instruction.h"
48#include "llvm/IR/Instructions.h"
49#include "llvm/IR/IntrinsicInst.h"
50#include "llvm/IR/IntrinsicsAMDGPU.h"
51#include "llvm/IR/IntrinsicsNVPTX.h"
52#include "llvm/IR/LLVMContext.h"
53#include "llvm/Support/Casting.h"
54#include "llvm/Support/CommandLine.h"
55#include "llvm/Support/Debug.h"
56#include "llvm/Transforms/IPO/Attributor.h"
57#include "llvm/Transforms/Utils/BasicBlockUtils.h"
58#include "llvm/Transforms/Utils/CallGraphUpdater.h"
59
60#include <algorithm>
61#include <optional>
62#include <string>
63
64using namespace llvm;
65using namespace omp;
66
67#define DEBUG_TYPE "openmp-opt"
68
69static cl::opt<bool> DisableOpenMPOptimizations(
70    "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
71    cl::Hidden, cl::init(false));
72
73static cl::opt<bool> EnableParallelRegionMerging(
74    "openmp-opt-enable-merging",
75    cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
76    cl::init(false));
77
78static cl::opt<bool>
79    DisableInternalization("openmp-opt-disable-internalization",
80                           cl::desc("Disable function internalization."),
81                           cl::Hidden, cl::init(false));
82
83static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
84                                     cl::init(false), cl::Hidden);
85static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
86                                    cl::Hidden);
87static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
88                                        cl::init(false), cl::Hidden);
89
90static cl::opt<bool> HideMemoryTransferLatency(
91    "openmp-hide-memory-transfer-latency",
92    cl::desc("[WIP] Tries to hide the latency of host to device memory"
93             " transfers"),
94    cl::Hidden, cl::init(false));
95
96static cl::opt<bool> DisableOpenMPOptDeglobalization(
97    "openmp-opt-disable-deglobalization",
98    cl::desc("Disable OpenMP optimizations involving deglobalization."),
99    cl::Hidden, cl::init(false));
100
101static cl::opt<bool> DisableOpenMPOptSPMDization(
102    "openmp-opt-disable-spmdization",
103    cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
104    cl::Hidden, cl::init(false));
105
106static cl::opt<bool> DisableOpenMPOptFolding(
107    "openmp-opt-disable-folding",
108    cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
109    cl::init(false));
110
111static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
112    "openmp-opt-disable-state-machine-rewrite",
113    cl::desc("Disable OpenMP optimizations that replace the state machine."),
114    cl::Hidden, cl::init(false));
115
116static cl::opt<bool> DisableOpenMPOptBarrierElimination(
117    "openmp-opt-disable-barrier-elimination",
118    cl::desc("Disable OpenMP optimizations that eliminate barriers."),
119    cl::Hidden, cl::init(false));
120
121static cl::opt<bool> PrintModuleAfterOptimizations(
122    "openmp-opt-print-module-after",
123    cl::desc("Print the current module after OpenMP optimizations."),
124    cl::Hidden, cl::init(false));
125
126static cl::opt<bool> PrintModuleBeforeOptimizations(
127    "openmp-opt-print-module-before",
128    cl::desc("Print the current module before OpenMP optimizations."),
129    cl::Hidden, cl::init(false));
130
131static cl::opt<bool> AlwaysInlineDeviceFunctions(
132    "openmp-opt-inline-device",
133    cl::desc("Inline all applicible functions on the device."), cl::Hidden,
134    cl::init(false));
135
136static cl::opt<bool>
137    EnableVerboseRemarks("openmp-opt-verbose-remarks",
138                         cl::desc("Enables more verbose remarks."), cl::Hidden,
139                         cl::init(false));
140
141static cl::opt<unsigned>
142    SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
143                          cl::desc("Maximal number of attributor iterations."),
144                          cl::init(256));
145
146static cl::opt<unsigned>
147    SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
148                      cl::desc("Maximum amount of shared memory to use."),
149                      cl::init(std::numeric_limits<unsigned>::max()));
150
151STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
152          "Number of OpenMP runtime calls deduplicated");
153STATISTIC(NumOpenMPParallelRegionsDeleted,
154          "Number of OpenMP parallel regions deleted");
155STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
156          "Number of OpenMP runtime functions identified");
157STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
158          "Number of OpenMP runtime function uses identified");
159STATISTIC(NumOpenMPTargetRegionKernels,
160          "Number of OpenMP target region entry points (=kernels) identified");
161STATISTIC(NumNonOpenMPTargetRegionKernels,
162          "Number of non-OpenMP target region kernels identified");
163STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
164          "Number of OpenMP target region entry points (=kernels) executed in "
165          "SPMD-mode instead of generic-mode");
166STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
167          "Number of OpenMP target region entry points (=kernels) executed in "
168          "generic-mode without a state machines");
169STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
170          "Number of OpenMP target region entry points (=kernels) executed in "
171          "generic-mode with customized state machines with fallback");
172STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
173          "Number of OpenMP target region entry points (=kernels) executed in "
174          "generic-mode with customized state machines without fallback");
175STATISTIC(
176    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
177    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
178STATISTIC(NumOpenMPParallelRegionsMerged,
179          "Number of OpenMP parallel regions merged");
180STATISTIC(NumBytesMovedToSharedMemory,
181          "Amount of memory pushed to shared memory");
182STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
183
184#if !defined(NDEBUG)
185static constexpr auto TAG = "[" DEBUG_TYPE "]";
186#endif
187
188namespace KernelInfo {
189
190// struct ConfigurationEnvironmentTy {
191//   uint8_t UseGenericStateMachine;
192//   uint8_t MayUseNestedParallelism;
193//   llvm::omp::OMPTgtExecModeFlags ExecMode;
194//   int32_t MinThreads;
195//   int32_t MaxThreads;
196//   int32_t MinTeams;
197//   int32_t MaxTeams;
198// };
199
200// struct DynamicEnvironmentTy {
201//   uint16_t DebugIndentionLevel;
202// };
203
204// struct KernelEnvironmentTy {
205//   ConfigurationEnvironmentTy Configuration;
206//   IdentTy *Ident;
207//   DynamicEnvironmentTy *DynamicEnv;
208// };
209
210#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)                                    \
211  constexpr const unsigned MEMBER##Idx = IDX;
212
213KERNEL_ENVIRONMENT_IDX(Configuration, 0)
214KERNEL_ENVIRONMENT_IDX(Ident, 1)
215
216#undef KERNEL_ENVIRONMENT_IDX
217
218#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)                      \
219  constexpr const unsigned MEMBER##Idx = IDX;
220
221KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
222KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
223KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
224KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
225KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
226KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
227KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)
228
229#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
230
231#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)                          \
232  RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
233    return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx));     \
234  }
235
236KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
237KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
238
239#undef KERNEL_ENVIRONMENT_GETTER
240
241#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)                        \
242  ConstantInt *get##MEMBER##FromKernelEnvironment(                             \
243      ConstantStruct *KernelEnvC) {                                            \
244    ConstantStruct *ConfigC =                                                  \
245        getConfigurationFromKernelEnvironment(KernelEnvC);                     \
246    return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx));   \
247  }
248
249KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
250KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
251KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
252KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
253KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
254KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
255KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)
256
257#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
258
259GlobalVariable *
260getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
261  constexpr const int InitKernelEnvironmentArgNo = 0;
262  return cast<GlobalVariable>(
263      KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
264          ->stripPointerCasts());
265}
266
267ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {
268  GlobalVariable *KernelEnvGV =
269      getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
270  return cast<ConstantStruct>(KernelEnvGV->getInitializer());
271}
272} // namespace KernelInfo
273
274namespace {
275
276struct AAHeapToShared;
277
278struct AAICVTracker;
279
280/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
281/// Attributor runs.
282struct OMPInformationCache : public InformationCache {
283  OMPInformationCache(Module &M, AnalysisGetter &AG,
284                      BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
285                      bool OpenMPPostLink)
286      : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
287        OpenMPPostLink(OpenMPPostLink) {
288
289    OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
290    OMPBuilder.initialize();
291    initializeRuntimeFunctions(M);
292    initializeInternalControlVars();
293  }
294
295  /// Generic information that describes an internal control variable.
296  struct InternalControlVarInfo {
297    /// The kind, as described by InternalControlVar enum.
298    InternalControlVar Kind;
299
300    /// The name of the ICV.
301    StringRef Name;
302
303    /// Environment variable associated with this ICV.
304    StringRef EnvVarName;
305
306    /// Initial value kind.
307    ICVInitValue InitKind;
308
309    /// Initial value.
310    ConstantInt *InitValue;
311
312    /// Setter RTL function associated with this ICV.
313    RuntimeFunction Setter;
314
315    /// Getter RTL function associated with this ICV.
316    RuntimeFunction Getter;
317
318    /// RTL Function corresponding to the override clause of this ICV
319    RuntimeFunction Clause;
320  };
321
322  /// Generic information that describes a runtime function
323  struct RuntimeFunctionInfo {
324
325    /// The kind, as described by the RuntimeFunction enum.
326    RuntimeFunction Kind;
327
328    /// The name of the function.
329    StringRef Name;
330
331    /// Flag to indicate a variadic function.
332    bool IsVarArg;
333
334    /// The return type of the function.
335    Type *ReturnType;
336
337    /// The argument types of the function.
338    SmallVector<Type *, 8> ArgumentTypes;
339
340    /// The declaration if available.
341    Function *Declaration = nullptr;
342
343    /// Uses of this runtime function per function containing the use.
344    using UseVector = SmallVector<Use *, 16>;
345
346    /// Clear UsesMap for runtime function.
347    void clearUsesMap() { UsesMap.clear(); }
348
349    /// Boolean conversion that is true if the runtime function was found.
350    operator bool() const { return Declaration; }
351
352    /// Return the vector of uses in function \p F.
353    UseVector &getOrCreateUseVector(Function *F) {
354      std::shared_ptr<UseVector> &UV = UsesMap[F];
355      if (!UV)
356        UV = std::make_shared<UseVector>();
357      return *UV;
358    }
359
360    /// Return the vector of uses in function \p F or `nullptr` if there are
361    /// none.
362    const UseVector *getUseVector(Function &F) const {
363      auto I = UsesMap.find(&F);
364      if (I != UsesMap.end())
365        return I->second.get();
366      return nullptr;
367    }
368
369    /// Return how many functions contain uses of this runtime function.
370    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
371
372    /// Return the number of arguments (or the minimal number for variadic
373    /// functions).
374    size_t getNumArgs() const { return ArgumentTypes.size(); }
375
376    /// Run the callback \p CB on each use and forget the use if the result is
377    /// true. The callback will be fed the function in which the use was
378    /// encountered as second argument.
379    void foreachUse(SmallVectorImpl<Function *> &SCC,
380                    function_ref<bool(Use &, Function &)> CB) {
381      for (Function *F : SCC)
382        foreachUse(CB, F);
383    }
384
385    /// Run the callback \p CB on each use within the function \p F and forget
386    /// the use if the result is true.
387    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
388      SmallVector<unsigned, 8> ToBeDeleted;
389      ToBeDeleted.clear();
390
391      unsigned Idx = 0;
392      UseVector &UV = getOrCreateUseVector(F);
393
394      for (Use *U : UV) {
395        if (CB(*U, *F))
396          ToBeDeleted.push_back(Idx);
397        ++Idx;
398      }
399
400      // Remove the to-be-deleted indices in reverse order as prior
401      // modifications will not modify the smaller indices.
402      while (!ToBeDeleted.empty()) {
403        unsigned Idx = ToBeDeleted.pop_back_val();
404        UV[Idx] = UV.back();
405        UV.pop_back();
406      }
407    }
408
409  private:
410    /// Map from functions to all uses of this runtime function contained in
411    /// them.
412    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
413
414  public:
415    /// Iterators for the uses of this runtime function.
416    decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
417    decltype(UsesMap)::iterator end() { return UsesMap.end(); }
418  };
419
420  /// An OpenMP-IR-Builder instance
421  OpenMPIRBuilder OMPBuilder;
422
423  /// Map from runtime function kind to the runtime function description.
424  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
425                  RuntimeFunction::OMPRTL___last>
426      RFIs;
427
428  /// Map from function declarations/definitions to their runtime enum type.
429  DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
430
431  /// Map from ICV kind to the ICV description.
432  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
433                  InternalControlVar::ICV___last>
434      ICVs;
435
436  /// Helper to initialize all internal control variable information for those
437  /// defined in OMPKinds.def.
438  void initializeInternalControlVars() {
439#define ICV_RT_SET(_Name, RTL)                                                 \
440  {                                                                            \
441    auto &ICV = ICVs[_Name];                                                   \
442    ICV.Setter = RTL;                                                          \
443  }
444#define ICV_RT_GET(Name, RTL)                                                  \
445  {                                                                            \
446    auto &ICV = ICVs[Name];                                                    \
447    ICV.Getter = RTL;                                                          \
448  }
449#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
450  {                                                                            \
451    auto &ICV = ICVs[Enum];                                                    \
452    ICV.Name = _Name;                                                          \
453    ICV.Kind = Enum;                                                           \
454    ICV.InitKind = Init;                                                       \
455    ICV.EnvVarName = _EnvVarName;                                              \
456    switch (ICV.InitKind) {                                                    \
457    case ICV_IMPLEMENTATION_DEFINED:                                           \
458      ICV.InitValue = nullptr;                                                 \
459      break;                                                                   \
460    case ICV_ZERO:                                                             \
461      ICV.InitValue = ConstantInt::get(                                        \
462          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
463      break;                                                                   \
464    case ICV_FALSE:                                                            \
465      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
466      break;                                                                   \
467    case ICV_LAST:                                                             \
468      break;                                                                   \
469    }                                                                          \
470  }
471#include "llvm/Frontend/OpenMP/OMPKinds.def"
472  }
473
474  /// Returns true if the function declaration \p F matches the runtime
475  /// function types, that is, return type \p RTFRetType, and argument types
476  /// \p RTFArgTypes.
477  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
478                                  SmallVector<Type *, 8> &RTFArgTypes) {
479    // TODO: We should output information to the user (under debug output
480    //       and via remarks).
481
482    if (!F)
483      return false;
484    if (F->getReturnType() != RTFRetType)
485      return false;
486    if (F->arg_size() != RTFArgTypes.size())
487      return false;
488
489    auto *RTFTyIt = RTFArgTypes.begin();
490    for (Argument &Arg : F->args()) {
491      if (Arg.getType() != *RTFTyIt)
492        return false;
493
494      ++RTFTyIt;
495    }
496
497    return true;
498  }
499
500  // Helper to collect all uses of the declaration in the UsesMap.
501  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
502    unsigned NumUses = 0;
503    if (!RFI.Declaration)
504      return NumUses;
505    OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
506
507    if (CollectStats) {
508      NumOpenMPRuntimeFunctionsIdentified += 1;
509      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
510    }
511
512    // TODO: We directly convert uses into proper calls and unknown uses.
513    for (Use &U : RFI.Declaration->uses()) {
514      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
515        if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
516          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
517          ++NumUses;
518        }
519      } else {
520        RFI.getOrCreateUseVector(nullptr).push_back(&U);
521        ++NumUses;
522      }
523    }
524    return NumUses;
525  }
526
527  // Helper function to recollect uses of a runtime function.
528  void recollectUsesForFunction(RuntimeFunction RTF) {
529    auto &RFI = RFIs[RTF];
530    RFI.clearUsesMap();
531    collectUses(RFI, /*CollectStats*/ false);
532  }
533
534  // Helper function to recollect uses of all runtime functions.
535  void recollectUses() {
536    for (int Idx = 0; Idx < RFIs.size(); ++Idx)
537      recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
538  }
539
540  // Helper function to inherit the calling convention of the function callee.
541  void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
542    if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
543      CI->setCallingConv(Fn->getCallingConv());
544  }
545
546  // Helper function to determine if it's legal to create a call to the runtime
547  // functions.
548  bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
549    // We can always emit calls if we haven't yet linked in the runtime.
550    if (!OpenMPPostLink)
551      return true;
552
553    // Once the runtime has been already been linked in we cannot emit calls to
554    // any undefined functions.
555    for (RuntimeFunction Fn : Fns) {
556      RuntimeFunctionInfo &RFI = RFIs[Fn];
557
558      if (RFI.Declaration && RFI.Declaration->isDeclaration())
559        return false;
560    }
561    return true;
562  }
563
564  /// Helper to initialize all runtime function information for those defined
565  /// in OpenMPKinds.def.
566  void initializeRuntimeFunctions(Module &M) {
567
568    // Helper macros for handling __VA_ARGS__ in OMP_RTL
569#define OMP_TYPE(VarName, ...)                                                 \
570  Type *VarName = OMPBuilder.VarName;                                          \
571  (void)VarName;
572
573#define OMP_ARRAY_TYPE(VarName, ...)                                           \
574  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
575  (void)VarName##Ty;                                                           \
576  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
577  (void)VarName##PtrTy;
578
579#define OMP_FUNCTION_TYPE(VarName, ...)                                        \
580  FunctionType *VarName = OMPBuilder.VarName;                                  \
581  (void)VarName;                                                               \
582  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
583  (void)VarName##Ptr;
584
585#define OMP_STRUCT_TYPE(VarName, ...)                                          \
586  StructType *VarName = OMPBuilder.VarName;                                    \
587  (void)VarName;                                                               \
588  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
589  (void)VarName##Ptr;
590
591#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
592  {                                                                            \
593    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
594    Function *F = M.getFunction(_Name);                                        \
595    RTLFunctions.insert(F);                                                    \
596    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
597      RuntimeFunctionIDMap[F] = _Enum;                                         \
598      auto &RFI = RFIs[_Enum];                                                 \
599      RFI.Kind = _Enum;                                                        \
600      RFI.Name = _Name;                                                        \
601      RFI.IsVarArg = _IsVarArg;                                                \
602      RFI.ReturnType = OMPBuilder._ReturnType;                                 \
603      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
604      RFI.Declaration = F;                                                     \
605      unsigned NumUses = collectUses(RFI);                                     \
606      (void)NumUses;                                                           \
607      LLVM_DEBUG({                                                             \
608        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
609               << " found\n";                                                  \
610        if (RFI.Declaration)                                                   \
611          dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
612                 << RFI.getNumFunctionsWithUses()                              \
613                 << " different functions.\n";                                 \
614      });                                                                      \
615    }                                                                          \
616  }
617#include "llvm/Frontend/OpenMP/OMPKinds.def"
618
619    // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
620    // functions, except if `optnone` is present.
621    if (isOpenMPDevice(M)) {
622      for (Function &F : M) {
623        for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
624          if (F.hasFnAttribute(Attribute::NoInline) &&
625              F.getName().starts_with(Prefix) &&
626              !F.hasFnAttribute(Attribute::OptimizeNone))
627            F.removeFnAttr(Attribute::NoInline);
628      }
629    }
630
631    // TODO: We should attach the attributes defined in OMPKinds.def.
632  }
633
634  /// Collection of known OpenMP runtime functions..
635  DenseSet<const Function *> RTLFunctions;
636
637  /// Indicates if we have already linked in the OpenMP device library.
638  bool OpenMPPostLink = false;
639};
640
641template <typename Ty, bool InsertInvalidates = true>
642struct BooleanStateWithSetVector : public BooleanState {
643  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
644  bool insert(const Ty &Elem) {
645    if (InsertInvalidates)
646      BooleanState::indicatePessimisticFixpoint();
647    return Set.insert(Elem);
648  }
649
650  const Ty &operator[](int Idx) const { return Set[Idx]; }
651  bool operator==(const BooleanStateWithSetVector &RHS) const {
652    return BooleanState::operator==(RHS) && Set == RHS.Set;
653  }
654  bool operator!=(const BooleanStateWithSetVector &RHS) const {
655    return !(*this == RHS);
656  }
657
658  bool empty() const { return Set.empty(); }
659  size_t size() const { return Set.size(); }
660
661  /// "Clamp" this state with \p RHS.
662  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
663    BooleanState::operator^=(RHS);
664    Set.insert(RHS.Set.begin(), RHS.Set.end());
665    return *this;
666  }
667
668private:
669  /// A set to keep track of elements.
670  SetVector<Ty> Set;
671
672public:
673  typename decltype(Set)::iterator begin() { return Set.begin(); }
674  typename decltype(Set)::iterator end() { return Set.end(); }
675  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
676  typename decltype(Set)::const_iterator end() const { return Set.end(); }
677};
678
679template <typename Ty, bool InsertInvalidates = true>
680using BooleanStateWithPtrSetVector =
681    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
682
683struct KernelInfoState : AbstractState {
684  /// Flag to track if we reached a fixpoint.
685  bool IsAtFixpoint = false;
686
687  /// The parallel regions (identified by the outlined parallel functions) that
688  /// can be reached from the associated function.
689  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
690      ReachedKnownParallelRegions;
691
692  /// State to track what parallel region we might reach.
693  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
694
695  /// State to track if we are in SPMD-mode, assumed or know, and why we decided
696  /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
697  /// false.
698  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
699
700  /// The __kmpc_target_init call in this kernel, if any. If we find more than
701  /// one we abort as the kernel is malformed.
702  CallBase *KernelInitCB = nullptr;
703
704  /// The constant kernel environement as taken from and passed to
705  /// __kmpc_target_init.
706  ConstantStruct *KernelEnvC = nullptr;
707
708  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
709  /// one we abort as the kernel is malformed.
710  CallBase *KernelDeinitCB = nullptr;
711
712  /// Flag to indicate if the associated function is a kernel entry.
713  bool IsKernelEntry = false;
714
715  /// State to track what kernel entries can reach the associated function.
716  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
717
718  /// State to indicate if we can track parallel level of the associated
719  /// function. We will give up tracking if we encounter unknown caller or the
720  /// caller is __kmpc_parallel_51.
721  BooleanStateWithSetVector<uint8_t> ParallelLevels;
722
723  /// Flag that indicates if the kernel has nested Parallelism
724  bool NestedParallelism = false;
725
726  /// Abstract State interface
727  ///{
728
729  KernelInfoState() = default;
730  KernelInfoState(bool BestState) {
731    if (!BestState)
732      indicatePessimisticFixpoint();
733  }
734
735  /// See AbstractState::isValidState(...)
736  bool isValidState() const override { return true; }
737
738  /// See AbstractState::isAtFixpoint(...)
739  bool isAtFixpoint() const override { return IsAtFixpoint; }
740
741  /// See AbstractState::indicatePessimisticFixpoint(...)
742  ChangeStatus indicatePessimisticFixpoint() override {
743    IsAtFixpoint = true;
744    ParallelLevels.indicatePessimisticFixpoint();
745    ReachingKernelEntries.indicatePessimisticFixpoint();
746    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
747    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
748    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
749    NestedParallelism = true;
750    return ChangeStatus::CHANGED;
751  }
752
753  /// See AbstractState::indicateOptimisticFixpoint(...)
754  ChangeStatus indicateOptimisticFixpoint() override {
755    IsAtFixpoint = true;
756    ParallelLevels.indicateOptimisticFixpoint();
757    ReachingKernelEntries.indicateOptimisticFixpoint();
758    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
759    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
760    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
761    return ChangeStatus::UNCHANGED;
762  }
763
764  /// Return the assumed state
765  KernelInfoState &getAssumed() { return *this; }
766  const KernelInfoState &getAssumed() const { return *this; }
767
768  bool operator==(const KernelInfoState &RHS) const {
769    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
770      return false;
771    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
772      return false;
773    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
774      return false;
775    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
776      return false;
777    if (ParallelLevels != RHS.ParallelLevels)
778      return false;
779    if (NestedParallelism != RHS.NestedParallelism)
780      return false;
781    return true;
782  }
783
784  /// Returns true if this kernel contains any OpenMP parallel regions.
785  bool mayContainParallelRegion() {
786    return !ReachedKnownParallelRegions.empty() ||
787           !ReachedUnknownParallelRegions.empty();
788  }
789
790  /// Return empty set as the best state of potential values.
791  static KernelInfoState getBestState() { return KernelInfoState(true); }
792
793  static KernelInfoState getBestState(KernelInfoState &KIS) {
794    return getBestState();
795  }
796
797  /// Return full set as the worst state of potential values.
798  static KernelInfoState getWorstState() { return KernelInfoState(false); }
799
800  /// "Clamp" this state with \p KIS.
801  KernelInfoState operator^=(const KernelInfoState &KIS) {
802    // Do not merge two different _init and _deinit call sites.
803    if (KIS.KernelInitCB) {
804      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
805        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
806                         "assumptions.");
807      KernelInitCB = KIS.KernelInitCB;
808    }
809    if (KIS.KernelDeinitCB) {
810      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
811        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
812                         "assumptions.");
813      KernelDeinitCB = KIS.KernelDeinitCB;
814    }
815    if (KIS.KernelEnvC) {
816      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
817        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
818                         "assumptions.");
819      KernelEnvC = KIS.KernelEnvC;
820    }
821    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
822    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
823    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
824    NestedParallelism |= KIS.NestedParallelism;
825    return *this;
826  }
827
828  KernelInfoState operator&=(const KernelInfoState &KIS) {
829    return (*this ^= KIS);
830  }
831
832  ///}
833};
834
835/// Used to map the values physically (in the IR) stored in an offload
836/// array, to a vector in memory.
837struct OffloadArray {
838  /// Physical array (in the IR).
839  AllocaInst *Array = nullptr;
840  /// Mapped values.
841  SmallVector<Value *, 8> StoredValues;
842  /// Last stores made in the offload array.
843  SmallVector<StoreInst *, 8> LastAccesses;
844
845  OffloadArray() = default;
846
847  /// Initializes the OffloadArray with the values stored in \p Array before
848  /// instruction \p Before is reached. Returns false if the initialization
849  /// fails.
850  /// This MUST be used immediately after the construction of the object.
851  bool initialize(AllocaInst &Array, Instruction &Before) {
852    if (!Array.getAllocatedType()->isArrayTy())
853      return false;
854
855    if (!getValues(Array, Before))
856      return false;
857
858    this->Array = &Array;
859    return true;
860  }
861
862  static const unsigned DeviceIDArgNum = 1;
863  static const unsigned BasePtrsArgNum = 3;
864  static const unsigned PtrsArgNum = 4;
865  static const unsigned SizesArgNum = 5;
866
867private:
868  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
869  /// \p Array, leaving StoredValues with the values stored before the
870  /// instruction \p Before is reached.
871  bool getValues(AllocaInst &Array, Instruction &Before) {
872    // Initialize container.
873    const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
874    StoredValues.assign(NumValues, nullptr);
875    LastAccesses.assign(NumValues, nullptr);
876
877    // TODO: This assumes the instruction \p Before is in the same
878    //  BasicBlock as Array. Make it general, for any control flow graph.
879    BasicBlock *BB = Array.getParent();
880    if (BB != Before.getParent())
881      return false;
882
883    const DataLayout &DL = Array.getModule()->getDataLayout();
884    const unsigned int PointerSize = DL.getPointerSize();
885
886    for (Instruction &I : *BB) {
887      if (&I == &Before)
888        break;
889
890      if (!isa<StoreInst>(&I))
891        continue;
892
893      auto *S = cast<StoreInst>(&I);
894      int64_t Offset = -1;
895      auto *Dst =
896          GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
897      if (Dst == &Array) {
898        int64_t Idx = Offset / PointerSize;
899        StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
900        LastAccesses[Idx] = S;
901      }
902    }
903
904    return isFilled();
905  }
906
907  /// Returns true if all values in StoredValues and
908  /// LastAccesses are not nullptrs.
909  bool isFilled() {
910    const unsigned NumValues = StoredValues.size();
911    for (unsigned I = 0; I < NumValues; ++I) {
912      if (!StoredValues[I] || !LastAccesses[I])
913        return false;
914    }
915
916    return true;
917  }
918};
919
920struct OpenMPOpt {
921
922  using OptimizationRemarkGetter =
923      function_ref<OptimizationRemarkEmitter &(Function *)>;
924
925  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
926            OptimizationRemarkGetter OREGetter,
927            OMPInformationCache &OMPInfoCache, Attributor &A)
928      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
929        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
930
931  /// Check if any remarks are enabled for openmp-opt
932  bool remarksEnabled() {
933    auto &Ctx = M.getContext();
934    return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
935  }
936
937  /// Run all OpenMP optimizations on the underlying SCC.
938  bool run(bool IsModulePass) {
939    if (SCC.empty())
940      return false;
941
942    bool Changed = false;
943
944    LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
945                      << " functions\n");
946
947    if (IsModulePass) {
948      Changed |= runAttributor(IsModulePass);
949
950      // Recollect uses, in case Attributor deleted any.
951      OMPInfoCache.recollectUses();
952
953      // TODO: This should be folded into buildCustomStateMachine.
954      Changed |= rewriteDeviceCodeStateMachine();
955
956      if (remarksEnabled())
957        analysisGlobalization();
958    } else {
959      if (PrintICVValues)
960        printICVs();
961      if (PrintOpenMPKernels)
962        printKernels();
963
964      Changed |= runAttributor(IsModulePass);
965
966      // Recollect uses, in case Attributor deleted any.
967      OMPInfoCache.recollectUses();
968
969      Changed |= deleteParallelRegions();
970
971      if (HideMemoryTransferLatency)
972        Changed |= hideMemTransfersLatency();
973      Changed |= deduplicateRuntimeCalls();
974      if (EnableParallelRegionMerging) {
975        if (mergeParallelRegions()) {
976          deduplicateRuntimeCalls();
977          Changed = true;
978        }
979      }
980    }
981
982    if (OMPInfoCache.OpenMPPostLink)
983      Changed |= removeRuntimeSymbols();
984
985    return Changed;
986  }
987
988  /// Print initial ICV values for testing.
989  /// FIXME: This should be done from the Attributor once it is added.
990  void printICVs() const {
991    InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
992                                 ICV_proc_bind};
993
994    for (Function *F : SCC) {
995      for (auto ICV : ICVs) {
996        auto ICVInfo = OMPInfoCache.ICVs[ICV];
997        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
998          return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
999                     << " Value: "
1000                     << (ICVInfo.InitValue
1001                             ? toString(ICVInfo.InitValue->getValue(), 10, true)
1002                             : "IMPLEMENTATION_DEFINED");
1003        };
1004
1005        emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
1006      }
1007    }
1008  }
1009
1010  /// Print OpenMP GPU kernels for testing.
1011  void printKernels() const {
1012    for (Function *F : SCC) {
1013      if (!omp::isOpenMPKernel(*F))
1014        continue;
1015
1016      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1017        return ORA << "OpenMP GPU kernel "
1018                   << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1019      };
1020
1021      emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
1022    }
1023  }
1024
1025  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
1026  /// given it has to be the callee or a nullptr is returned.
1027  static CallInst *getCallIfRegularCall(
1028      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1029    CallInst *CI = dyn_cast<CallInst>(U.getUser());
1030    if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
1031        (!RFI ||
1032         (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1033      return CI;
1034    return nullptr;
1035  }
1036
1037  /// Return the call if \p V is a regular call. If \p RFI is given it has to be
1038  /// the callee or a nullptr is returned.
1039  static CallInst *getCallIfRegularCall(
1040      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1041    CallInst *CI = dyn_cast<CallInst>(&V);
1042    if (CI && !CI->hasOperandBundles() &&
1043        (!RFI ||
1044         (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1045      return CI;
1046    return nullptr;
1047  }
1048
1049private:
1050  /// Merge parallel regions when it is safe.
1051  bool mergeParallelRegions() {
1052    const unsigned CallbackCalleeOperand = 2;
1053    const unsigned CallbackFirstArgOperand = 3;
1054    using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1055
1056    // Check if there are any __kmpc_fork_call calls to merge.
1057    OMPInformationCache::RuntimeFunctionInfo &RFI =
1058        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1059
1060    if (!RFI.Declaration)
1061      return false;
1062
1063    // Unmergable calls that prevent merging a parallel region.
1064    OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1065        OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1066        OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1067    };
1068
1069    bool Changed = false;
1070    LoopInfo *LI = nullptr;
1071    DominatorTree *DT = nullptr;
1072
1073    SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
1074
1075    BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1076    auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1077      BasicBlock *CGStartBB = CodeGenIP.getBlock();
1078      BasicBlock *CGEndBB =
1079          SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1080      assert(StartBB != nullptr && "StartBB should not be null");
1081      CGStartBB->getTerminator()->setSuccessor(0, StartBB);
1082      assert(EndBB != nullptr && "EndBB should not be null");
1083      EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1084    };
1085
1086    auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1087                      Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1088      ReplacementValue = &Inner;
1089      return CodeGenIP;
1090    };
1091
1092    auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1093
1094    /// Create a sequential execution region within a merged parallel region,
1095    /// encapsulated in a master construct with a barrier for synchronization.
1096    auto CreateSequentialRegion = [&](Function *OuterFn,
1097                                      BasicBlock *OuterPredBB,
1098                                      Instruction *SeqStartI,
1099                                      Instruction *SeqEndI) {
1100      // Isolate the instructions of the sequential region to a separate
1101      // block.
1102      BasicBlock *ParentBB = SeqStartI->getParent();
1103      BasicBlock *SeqEndBB =
1104          SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1105      BasicBlock *SeqAfterBB =
1106          SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1107      BasicBlock *SeqStartBB =
1108          SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1109
1110      assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1111             "Expected a different CFG");
1112      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1113      ParentBB->getTerminator()->eraseFromParent();
1114
1115      auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1116        BasicBlock *CGStartBB = CodeGenIP.getBlock();
1117        BasicBlock *CGEndBB =
1118            SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1119        assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1120        CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1121        assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1122        SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1123      };
1124      auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1125
1126      // Find outputs from the sequential region to outside users and
1127      // broadcast their values to them.
1128      for (Instruction &I : *SeqStartBB) {
1129        SmallPtrSet<Instruction *, 4> OutsideUsers;
1130        for (User *Usr : I.users()) {
1131          Instruction &UsrI = *cast<Instruction>(Usr);
1132          // Ignore outputs to LT intrinsics, code extraction for the merged
1133          // parallel region will fix them.
1134          if (UsrI.isLifetimeStartOrEnd())
1135            continue;
1136
1137          if (UsrI.getParent() != SeqStartBB)
1138            OutsideUsers.insert(&UsrI);
1139        }
1140
1141        if (OutsideUsers.empty())
1142          continue;
1143
1144        // Emit an alloca in the outer region to store the broadcasted
1145        // value.
1146        const DataLayout &DL = M.getDataLayout();
1147        AllocaInst *AllocaI = new AllocaInst(
1148            I.getType(), DL.getAllocaAddrSpace(), nullptr,
1149            I.getName() + ".seq.output.alloc", &OuterFn->front().front());
1150
1151        // Emit a store instruction in the sequential BB to update the
1152        // value.
1153        new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
1154
1155        // Emit a load instruction and replace the use of the output value
1156        // with it.
1157        for (Instruction *UsrI : OutsideUsers) {
1158          LoadInst *LoadI = new LoadInst(
1159              I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
1160          UsrI->replaceUsesOfWith(&I, LoadI);
1161        }
1162      }
1163
1164      OpenMPIRBuilder::LocationDescription Loc(
1165          InsertPointTy(ParentBB, ParentBB->end()), DL);
1166      InsertPointTy SeqAfterIP =
1167          OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1168
1169      OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1170
1171      BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1172
1173      LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1174                        << "\n");
1175    };
1176
1177    // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1178    // contained in BB and only separated by instructions that can be
1179    // redundantly executed in parallel. The block BB is split before the first
1180    // call (in MergableCIs) and after the last so the entire region we merge
1181    // into a single parallel region is contained in a single basic block
1182    // without any other instructions. We use the OpenMPIRBuilder to outline
1183    // that block and call the resulting function via __kmpc_fork_call.
1184    auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1185                     BasicBlock *BB) {
1186      // TODO: Change the interface to allow single CIs expanded, e.g, to
1187      // include an outer loop.
1188      assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1189
1190      auto Remark = [&](OptimizationRemark OR) {
1191        OR << "Parallel region merged with parallel region"
1192           << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1193        for (auto *CI : llvm::drop_begin(MergableCIs)) {
1194          OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1195          if (CI != MergableCIs.back())
1196            OR << ", ";
1197        }
1198        return OR << ".";
1199      };
1200
1201      emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1202
1203      Function *OriginalFn = BB->getParent();
1204      LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1205                        << " parallel regions in " << OriginalFn->getName()
1206                        << "\n");
1207
1208      // Isolate the calls to merge in a separate block.
1209      EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1210      BasicBlock *AfterBB =
1211          SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1212      StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1213                           "omp.par.merged");
1214
1215      assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1216      const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1217      BB->getTerminator()->eraseFromParent();
1218
1219      // Create sequential regions for sequential instructions that are
1220      // in-between mergable parallel regions.
1221      for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1222           It != End; ++It) {
1223        Instruction *ForkCI = *It;
1224        Instruction *NextForkCI = *(It + 1);
1225
1226        // Continue if there are not in-between instructions.
1227        if (ForkCI->getNextNode() == NextForkCI)
1228          continue;
1229
1230        CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1231                               NextForkCI->getPrevNode());
1232      }
1233
1234      OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1235                                               DL);
1236      IRBuilder<>::InsertPoint AllocaIP(
1237          &OriginalFn->getEntryBlock(),
1238          OriginalFn->getEntryBlock().getFirstInsertionPt());
1239      // Create the merged parallel region with default proc binding, to
1240      // avoid overriding binding settings, and without explicit cancellation.
1241      InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1242          Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1243          OMP_PROC_BIND_default, /* IsCancellable */ false);
1244      BranchInst::Create(AfterBB, AfterIP.getBlock());
1245
1246      // Perform the actual outlining.
1247      OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1248
1249      Function *OutlinedFn = MergableCIs.front()->getCaller();
1250
1251      // Replace the __kmpc_fork_call calls with direct calls to the outlined
1252      // callbacks.
1253      SmallVector<Value *, 8> Args;
1254      for (auto *CI : MergableCIs) {
1255        Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1256        FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1257        Args.clear();
1258        Args.push_back(OutlinedFn->getArg(0));
1259        Args.push_back(OutlinedFn->getArg(1));
1260        for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1261             ++U)
1262          Args.push_back(CI->getArgOperand(U));
1263
1264        CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1265        if (CI->getDebugLoc())
1266          NewCI->setDebugLoc(CI->getDebugLoc());
1267
1268        // Forward parameter attributes from the callback to the callee.
1269        for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1270             ++U)
1271          for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1272            NewCI->addParamAttr(
1273                U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1274
1275        // Emit an explicit barrier to replace the implicit fork-join barrier.
1276        if (CI != MergableCIs.back()) {
1277          // TODO: Remove barrier if the merged parallel region includes the
1278          // 'nowait' clause.
1279          OMPInfoCache.OMPBuilder.createBarrier(
1280              InsertPointTy(NewCI->getParent(),
1281                            NewCI->getNextNode()->getIterator()),
1282              OMPD_parallel);
1283        }
1284
1285        CI->eraseFromParent();
1286      }
1287
1288      assert(OutlinedFn != OriginalFn && "Outlining failed");
1289      CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1290      CGUpdater.reanalyzeFunction(*OriginalFn);
1291
1292      NumOpenMPParallelRegionsMerged += MergableCIs.size();
1293
1294      return true;
1295    };
1296
1297    // Helper function that identifes sequences of
1298    // __kmpc_fork_call uses in a basic block.
1299    auto DetectPRsCB = [&](Use &U, Function &F) {
1300      CallInst *CI = getCallIfRegularCall(U, &RFI);
1301      BB2PRMap[CI->getParent()].insert(CI);
1302
1303      return false;
1304    };
1305
1306    BB2PRMap.clear();
1307    RFI.foreachUse(SCC, DetectPRsCB);
1308    SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1309    // Find mergable parallel regions within a basic block that are
1310    // safe to merge, that is any in-between instructions can safely
1311    // execute in parallel after merging.
1312    // TODO: support merging across basic-blocks.
1313    for (auto &It : BB2PRMap) {
1314      auto &CIs = It.getSecond();
1315      if (CIs.size() < 2)
1316        continue;
1317
1318      BasicBlock *BB = It.getFirst();
1319      SmallVector<CallInst *, 4> MergableCIs;
1320
1321      /// Returns true if the instruction is mergable, false otherwise.
1322      /// A terminator instruction is unmergable by definition since merging
1323      /// works within a BB. Instructions before the mergable region are
1324      /// mergable if they are not calls to OpenMP runtime functions that may
1325      /// set different execution parameters for subsequent parallel regions.
1326      /// Instructions in-between parallel regions are mergable if they are not
1327      /// calls to any non-intrinsic function since that may call a non-mergable
1328      /// OpenMP runtime function.
1329      auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1330        // We do not merge across BBs, hence return false (unmergable) if the
1331        // instruction is a terminator.
1332        if (I.isTerminator())
1333          return false;
1334
1335        if (!isa<CallInst>(&I))
1336          return true;
1337
1338        CallInst *CI = cast<CallInst>(&I);
1339        if (IsBeforeMergableRegion) {
1340          Function *CalledFunction = CI->getCalledFunction();
1341          if (!CalledFunction)
1342            return false;
1343          // Return false (unmergable) if the call before the parallel
1344          // region calls an explicit affinity (proc_bind) or number of
1345          // threads (num_threads) compiler-generated function. Those settings
1346          // may be incompatible with following parallel regions.
1347          // TODO: ICV tracking to detect compatibility.
1348          for (const auto &RFI : UnmergableCallsInfo) {
1349            if (CalledFunction == RFI.Declaration)
1350              return false;
1351          }
1352        } else {
1353          // Return false (unmergable) if there is a call instruction
1354          // in-between parallel regions when it is not an intrinsic. It
1355          // may call an unmergable OpenMP runtime function in its callpath.
1356          // TODO: Keep track of possible OpenMP calls in the callpath.
1357          if (!isa<IntrinsicInst>(CI))
1358            return false;
1359        }
1360
1361        return true;
1362      };
1363      // Find maximal number of parallel region CIs that are safe to merge.
1364      for (auto It = BB->begin(), End = BB->end(); It != End;) {
1365        Instruction &I = *It;
1366        ++It;
1367
1368        if (CIs.count(&I)) {
1369          MergableCIs.push_back(cast<CallInst>(&I));
1370          continue;
1371        }
1372
1373        // Continue expanding if the instruction is mergable.
1374        if (IsMergable(I, MergableCIs.empty()))
1375          continue;
1376
1377        // Forward the instruction iterator to skip the next parallel region
1378        // since there is an unmergable instruction which can affect it.
1379        for (; It != End; ++It) {
1380          Instruction &SkipI = *It;
1381          if (CIs.count(&SkipI)) {
1382            LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1383                              << " due to " << I << "\n");
1384            ++It;
1385            break;
1386          }
1387        }
1388
1389        // Store mergable regions found.
1390        if (MergableCIs.size() > 1) {
1391          MergableCIsVector.push_back(MergableCIs);
1392          LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1393                            << " parallel regions in block " << BB->getName()
1394                            << " of function " << BB->getParent()->getName()
1395                            << "\n";);
1396        }
1397
1398        MergableCIs.clear();
1399      }
1400
1401      if (!MergableCIsVector.empty()) {
1402        Changed = true;
1403
1404        for (auto &MergableCIs : MergableCIsVector)
1405          Merge(MergableCIs, BB);
1406        MergableCIsVector.clear();
1407      }
1408    }
1409
1410    if (Changed) {
1411      /// Re-collect use for fork calls, emitted barrier calls, and
1412      /// any emitted master/end_master calls.
1413      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1414      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1415      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1416      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1417    }
1418
1419    return Changed;
1420  }
1421
1422  /// Try to delete parallel regions if possible.
1423  bool deleteParallelRegions() {
1424    const unsigned CallbackCalleeOperand = 2;
1425
1426    OMPInformationCache::RuntimeFunctionInfo &RFI =
1427        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1428
1429    if (!RFI.Declaration)
1430      return false;
1431
1432    bool Changed = false;
1433    auto DeleteCallCB = [&](Use &U, Function &) {
1434      CallInst *CI = getCallIfRegularCall(U);
1435      if (!CI)
1436        return false;
1437      auto *Fn = dyn_cast<Function>(
1438          CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1439      if (!Fn)
1440        return false;
1441      if (!Fn->onlyReadsMemory())
1442        return false;
1443      if (!Fn->hasFnAttribute(Attribute::WillReturn))
1444        return false;
1445
1446      LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1447                        << CI->getCaller()->getName() << "\n");
1448
1449      auto Remark = [&](OptimizationRemark OR) {
1450        return OR << "Removing parallel region with no side-effects.";
1451      };
1452      emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1453
1454      CGUpdater.removeCallSite(*CI);
1455      CI->eraseFromParent();
1456      Changed = true;
1457      ++NumOpenMPParallelRegionsDeleted;
1458      return true;
1459    };
1460
1461    RFI.foreachUse(SCC, DeleteCallCB);
1462
1463    return Changed;
1464  }
1465
1466  /// Try to eliminate runtime calls by reusing existing ones.
1467  bool deduplicateRuntimeCalls() {
1468    bool Changed = false;
1469
1470    RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1471        OMPRTL_omp_get_num_threads,
1472        OMPRTL_omp_in_parallel,
1473        OMPRTL_omp_get_cancellation,
1474        OMPRTL_omp_get_thread_limit,
1475        OMPRTL_omp_get_supported_active_levels,
1476        OMPRTL_omp_get_level,
1477        OMPRTL_omp_get_ancestor_thread_num,
1478        OMPRTL_omp_get_team_size,
1479        OMPRTL_omp_get_active_level,
1480        OMPRTL_omp_in_final,
1481        OMPRTL_omp_get_proc_bind,
1482        OMPRTL_omp_get_num_places,
1483        OMPRTL_omp_get_num_procs,
1484        OMPRTL_omp_get_place_num,
1485        OMPRTL_omp_get_partition_num_places,
1486        OMPRTL_omp_get_partition_place_nums};
1487
1488    // Global-tid is handled separately.
1489    SmallSetVector<Value *, 16> GTIdArgs;
1490    collectGlobalThreadIdArguments(GTIdArgs);
1491    LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1492                      << " global thread ID arguments\n");
1493
1494    for (Function *F : SCC) {
1495      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1496        Changed |= deduplicateRuntimeCalls(
1497            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1498
1499      // __kmpc_global_thread_num is special as we can replace it with an
1500      // argument in enough cases to make it worth trying.
1501      Value *GTIdArg = nullptr;
1502      for (Argument &Arg : F->args())
1503        if (GTIdArgs.count(&Arg)) {
1504          GTIdArg = &Arg;
1505          break;
1506        }
1507      Changed |= deduplicateRuntimeCalls(
1508          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1509    }
1510
1511    return Changed;
1512  }
1513
1514  /// Tries to remove known runtime symbols that are optional from the module.
1515  bool removeRuntimeSymbols() {
1516    // The RPC client symbol is defined in `libc` and indicates that something
1517    // required an RPC server. If its users were all optimized out then we can
1518    // safely remove it.
1519    // TODO: This should be somewhere more common in the future.
1520    if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) {
1521      if (!GV->getType()->isPointerTy())
1522        return false;
1523
1524      Constant *C = GV->getInitializer();
1525      if (!C)
1526        return false;
1527
1528      // Check to see if the only user of the RPC client is the external handle.
1529      GlobalVariable *Client = dyn_cast<GlobalVariable>(C->stripPointerCasts());
1530      if (!Client || Client->getNumUses() > 1 ||
1531          Client->user_back() != GV->getInitializer())
1532        return false;
1533
1534      Client->replaceAllUsesWith(PoisonValue::get(Client->getType()));
1535      Client->eraseFromParent();
1536
1537      GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
1538      GV->eraseFromParent();
1539
1540      return true;
1541    }
1542    return false;
1543  }
1544
1545  /// Tries to hide the latency of runtime calls that involve host to
1546  /// device memory transfers by splitting them into their "issue" and "wait"
1547  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1548  /// moved downards as much as possible. The "issue" issues the memory transfer
1549  /// asynchronously, returning a handle. The "wait" waits in the returned
1550  /// handle for the memory transfer to finish.
1551  bool hideMemTransfersLatency() {
1552    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1553    bool Changed = false;
1554    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1555      auto *RTCall = getCallIfRegularCall(U, &RFI);
1556      if (!RTCall)
1557        return false;
1558
1559      OffloadArray OffloadArrays[3];
1560      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1561        return false;
1562
1563      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1564
1565      // TODO: Check if can be moved upwards.
1566      bool WasSplit = false;
1567      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1568      if (WaitMovementPoint)
1569        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1570
1571      Changed |= WasSplit;
1572      return WasSplit;
1573    };
1574    if (OMPInfoCache.runtimeFnsAvailable(
1575            {OMPRTL___tgt_target_data_begin_mapper_issue,
1576             OMPRTL___tgt_target_data_begin_mapper_wait}))
1577      RFI.foreachUse(SCC, SplitMemTransfers);
1578
1579    return Changed;
1580  }
1581
1582  void analysisGlobalization() {
1583    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1584
1585    auto CheckGlobalization = [&](Use &U, Function &Decl) {
1586      if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1587        auto Remark = [&](OptimizationRemarkMissed ORM) {
1588          return ORM
1589                 << "Found thread data sharing on the GPU. "
1590                 << "Expect degraded performance due to data globalization.";
1591        };
1592        emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1593      }
1594
1595      return false;
1596    };
1597
1598    RFI.foreachUse(SCC, CheckGlobalization);
1599  }
1600
1601  /// Maps the values stored in the offload arrays passed as arguments to
1602  /// \p RuntimeCall into the offload arrays in \p OAs.
1603  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1604                                MutableArrayRef<OffloadArray> OAs) {
1605    assert(OAs.size() == 3 && "Need space for three offload arrays!");
1606
1607    // A runtime call that involves memory offloading looks something like:
1608    // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1609    //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1610    // ...)
1611    // So, the idea is to access the allocas that allocate space for these
1612    // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1613    // Therefore:
1614    // i8** %offload_baseptrs.
1615    Value *BasePtrsArg =
1616        RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1617    // i8** %offload_ptrs.
1618    Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1619    // i8** %offload_sizes.
1620    Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1621
1622    // Get values stored in **offload_baseptrs.
1623    auto *V = getUnderlyingObject(BasePtrsArg);
1624    if (!isa<AllocaInst>(V))
1625      return false;
1626    auto *BasePtrsArray = cast<AllocaInst>(V);
1627    if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1628      return false;
1629
1630    // Get values stored in **offload_baseptrs.
1631    V = getUnderlyingObject(PtrsArg);
1632    if (!isa<AllocaInst>(V))
1633      return false;
1634    auto *PtrsArray = cast<AllocaInst>(V);
1635    if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1636      return false;
1637
1638    // Get values stored in **offload_sizes.
1639    V = getUnderlyingObject(SizesArg);
1640    // If it's a [constant] global array don't analyze it.
1641    if (isa<GlobalValue>(V))
1642      return isa<Constant>(V);
1643    if (!isa<AllocaInst>(V))
1644      return false;
1645
1646    auto *SizesArray = cast<AllocaInst>(V);
1647    if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1648      return false;
1649
1650    return true;
1651  }
1652
1653  /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1654  /// For now this is a way to test that the function getValuesInOffloadArrays
1655  /// is working properly.
1656  /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1657  void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1658    assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1659
1660    LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1661    std::string ValuesStr;
1662    raw_string_ostream Printer(ValuesStr);
1663    std::string Separator = " --- ";
1664
1665    for (auto *BP : OAs[0].StoredValues) {
1666      BP->print(Printer);
1667      Printer << Separator;
1668    }
1669    LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1670    ValuesStr.clear();
1671
1672    for (auto *P : OAs[1].StoredValues) {
1673      P->print(Printer);
1674      Printer << Separator;
1675    }
1676    LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1677    ValuesStr.clear();
1678
1679    for (auto *S : OAs[2].StoredValues) {
1680      S->print(Printer);
1681      Printer << Separator;
1682    }
1683    LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1684  }
1685
1686  /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1687  /// moved. Returns nullptr if the movement is not possible, or not worth it.
1688  Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1689    // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1690    //  Make it traverse the CFG.
1691
1692    Instruction *CurrentI = &RuntimeCall;
1693    bool IsWorthIt = false;
1694    while ((CurrentI = CurrentI->getNextNode())) {
1695
1696      // TODO: Once we detect the regions to be offloaded we should use the
1697      //  alias analysis manager to check if CurrentI may modify one of
1698      //  the offloaded regions.
1699      if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1700        if (IsWorthIt)
1701          return CurrentI;
1702
1703        return nullptr;
1704      }
1705
1706      // FIXME: For now if we move it over anything without side effect
1707      //  is worth it.
1708      IsWorthIt = true;
1709    }
1710
1711    // Return end of BasicBlock.
1712    return RuntimeCall.getParent()->getTerminator();
1713  }
1714
1715  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1716  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1717                               Instruction &WaitMovementPoint) {
1718    // Create stack allocated handle (__tgt_async_info) at the beginning of the
1719    // function. Used for storing information of the async transfer, allowing to
1720    // wait on it later.
1721    auto &IRBuilder = OMPInfoCache.OMPBuilder;
1722    Function *F = RuntimeCall.getCaller();
1723    BasicBlock &Entry = F->getEntryBlock();
1724    IRBuilder.Builder.SetInsertPoint(&Entry,
1725                                     Entry.getFirstNonPHIOrDbgOrAlloca());
1726    Value *Handle = IRBuilder.Builder.CreateAlloca(
1727        IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1728    Handle =
1729        IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1730
1731    // Add "issue" runtime call declaration:
1732    // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1733    //   i8**, i8**, i64*, i64*)
1734    FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1735        M, OMPRTL___tgt_target_data_begin_mapper_issue);
1736
1737    // Change RuntimeCall call site for its asynchronous version.
1738    SmallVector<Value *, 16> Args;
1739    for (auto &Arg : RuntimeCall.args())
1740      Args.push_back(Arg.get());
1741    Args.push_back(Handle);
1742
1743    CallInst *IssueCallsite =
1744        CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1745    OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1746    RuntimeCall.eraseFromParent();
1747
1748    // Add "wait" runtime call declaration:
1749    // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1750    FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1751        M, OMPRTL___tgt_target_data_begin_mapper_wait);
1752
1753    Value *WaitParams[2] = {
1754        IssueCallsite->getArgOperand(
1755            OffloadArray::DeviceIDArgNum), // device_id.
1756        Handle                             // handle to wait on.
1757    };
1758    CallInst *WaitCallsite = CallInst::Create(
1759        WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1760    OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1761
1762    return true;
1763  }
1764
1765  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1766                                    bool GlobalOnly, bool &SingleChoice) {
1767    if (CurrentIdent == NextIdent)
1768      return CurrentIdent;
1769
1770    // TODO: Figure out how to actually combine multiple debug locations. For
1771    //       now we just keep an existing one if there is a single choice.
1772    if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1773      SingleChoice = !CurrentIdent;
1774      return NextIdent;
1775    }
1776    return nullptr;
1777  }
1778
1779  /// Return an `struct ident_t*` value that represents the ones used in the
1780  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1781  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1782  /// return value we create one from scratch. We also do not yet combine
1783  /// information, e.g., the source locations, see combinedIdentStruct.
1784  Value *
1785  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1786                                 Function &F, bool GlobalOnly) {
1787    bool SingleChoice = true;
1788    Value *Ident = nullptr;
1789    auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1790      CallInst *CI = getCallIfRegularCall(U, &RFI);
1791      if (!CI || &F != &Caller)
1792        return false;
1793      Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1794                                  /* GlobalOnly */ true, SingleChoice);
1795      return false;
1796    };
1797    RFI.foreachUse(SCC, CombineIdentStruct);
1798
1799    if (!Ident || !SingleChoice) {
1800      // The IRBuilder uses the insertion block to get to the module, this is
1801      // unfortunate but we work around it for now.
1802      if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1803        OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1804            &F.getEntryBlock(), F.getEntryBlock().begin()));
1805      // Create a fallback location if non was found.
1806      // TODO: Use the debug locations of the calls instead.
1807      uint32_t SrcLocStrSize;
1808      Constant *Loc =
1809          OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1810      Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1811    }
1812    return Ident;
1813  }
1814
1815  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1816  /// \p ReplVal if given.
1817  bool deduplicateRuntimeCalls(Function &F,
1818                               OMPInformationCache::RuntimeFunctionInfo &RFI,
1819                               Value *ReplVal = nullptr) {
1820    auto *UV = RFI.getUseVector(F);
1821    if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1822      return false;
1823
1824    LLVM_DEBUG(
1825        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1826               << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1827
1828    assert((!ReplVal || (isa<Argument>(ReplVal) &&
1829                         cast<Argument>(ReplVal)->getParent() == &F)) &&
1830           "Unexpected replacement value!");
1831
1832    // TODO: Use dominance to find a good position instead.
1833    auto CanBeMoved = [this](CallBase &CB) {
1834      unsigned NumArgs = CB.arg_size();
1835      if (NumArgs == 0)
1836        return true;
1837      if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1838        return false;
1839      for (unsigned U = 1; U < NumArgs; ++U)
1840        if (isa<Instruction>(CB.getArgOperand(U)))
1841          return false;
1842      return true;
1843    };
1844
1845    if (!ReplVal) {
1846      auto *DT =
1847          OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1848      if (!DT)
1849        return false;
1850      Instruction *IP = nullptr;
1851      for (Use *U : *UV) {
1852        if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1853          if (IP)
1854            IP = DT->findNearestCommonDominator(IP, CI);
1855          else
1856            IP = CI;
1857          if (!CanBeMoved(*CI))
1858            continue;
1859          if (!ReplVal)
1860            ReplVal = CI;
1861        }
1862      }
1863      if (!ReplVal)
1864        return false;
1865      assert(IP && "Expected insertion point!");
1866      cast<Instruction>(ReplVal)->moveBefore(IP);
1867    }
1868
1869    // If we use a call as a replacement value we need to make sure the ident is
1870    // valid at the new location. For now we just pick a global one, either
1871    // existing and used by one of the calls, or created from scratch.
1872    if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1873      if (!CI->arg_empty() &&
1874          CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1875        Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1876                                                      /* GlobalOnly */ true);
1877        CI->setArgOperand(0, Ident);
1878      }
1879    }
1880
1881    bool Changed = false;
1882    auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1883      CallInst *CI = getCallIfRegularCall(U, &RFI);
1884      if (!CI || CI == ReplVal || &F != &Caller)
1885        return false;
1886      assert(CI->getCaller() == &F && "Unexpected call!");
1887
1888      auto Remark = [&](OptimizationRemark OR) {
1889        return OR << "OpenMP runtime call "
1890                  << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1891      };
1892      if (CI->getDebugLoc())
1893        emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1894      else
1895        emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1896
1897      CGUpdater.removeCallSite(*CI);
1898      CI->replaceAllUsesWith(ReplVal);
1899      CI->eraseFromParent();
1900      ++NumOpenMPRuntimeCallsDeduplicated;
1901      Changed = true;
1902      return true;
1903    };
1904    RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1905
1906    return Changed;
1907  }
1908
1909  /// Collect arguments that represent the global thread id in \p GTIdArgs.
1910  void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1911    // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1912    //       initialization. We could define an AbstractAttribute instead and
1913    //       run the Attributor here once it can be run as an SCC pass.
1914
1915    // Helper to check the argument \p ArgNo at all call sites of \p F for
1916    // a GTId.
1917    auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1918      if (!F.hasLocalLinkage())
1919        return false;
1920      for (Use &U : F.uses()) {
1921        if (CallInst *CI = getCallIfRegularCall(U)) {
1922          Value *ArgOp = CI->getArgOperand(ArgNo);
1923          if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1924              getCallIfRegularCall(
1925                  *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1926            continue;
1927        }
1928        return false;
1929      }
1930      return true;
1931    };
1932
1933    // Helper to identify uses of a GTId as GTId arguments.
1934    auto AddUserArgs = [&](Value &GTId) {
1935      for (Use &U : GTId.uses())
1936        if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1937          if (CI->isArgOperand(&U))
1938            if (Function *Callee = CI->getCalledFunction())
1939              if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1940                GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1941    };
1942
1943    // The argument users of __kmpc_global_thread_num calls are GTIds.
1944    OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1945        OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1946
1947    GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1948      if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1949        AddUserArgs(*CI);
1950      return false;
1951    });
1952
1953    // Transitively search for more arguments by looking at the users of the
1954    // ones we know already. During the search the GTIdArgs vector is extended
1955    // so we cannot cache the size nor can we use a range based for.
1956    for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1957      AddUserArgs(*GTIdArgs[U]);
1958  }
1959
1960  /// Kernel (=GPU) optimizations and utility functions
1961  ///
1962  ///{{
1963
1964  /// Cache to remember the unique kernel for a function.
1965  DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1966
1967  /// Find the unique kernel that will execute \p F, if any.
1968  Kernel getUniqueKernelFor(Function &F);
1969
1970  /// Find the unique kernel that will execute \p I, if any.
1971  Kernel getUniqueKernelFor(Instruction &I) {
1972    return getUniqueKernelFor(*I.getFunction());
1973  }
1974
1975  /// Rewrite the device (=GPU) code state machine create in non-SPMD mode in
1976  /// the cases we can avoid taking the address of a function.
1977  bool rewriteDeviceCodeStateMachine();
1978
1979  ///
1980  ///}}
1981
1982  /// Emit a remark generically
1983  ///
1984  /// This template function can be used to generically emit a remark. The
1985  /// RemarkKind should be one of the following:
1986  ///   - OptimizationRemark to indicate a successful optimization attempt
1987  ///   - OptimizationRemarkMissed to report a failed optimization attempt
1988  ///   - OptimizationRemarkAnalysis to provide additional information about an
1989  ///     optimization attempt
1990  ///
1991  /// The remark is built using a callback function provided by the caller that
1992  /// takes a RemarkKind as input and returns a RemarkKind.
1993  template <typename RemarkKind, typename RemarkCallBack>
1994  void emitRemark(Instruction *I, StringRef RemarkName,
1995                  RemarkCallBack &&RemarkCB) const {
1996    Function *F = I->getParent()->getParent();
1997    auto &ORE = OREGetter(F);
1998
1999    if (RemarkName.starts_with("OMP"))
2000      ORE.emit([&]() {
2001        return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2002               << " [" << RemarkName << "]";
2003      });
2004    else
2005      ORE.emit(
2006          [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2007  }
2008
2009  /// Emit a remark on a function.
2010  template <typename RemarkKind, typename RemarkCallBack>
2011  void emitRemark(Function *F, StringRef RemarkName,
2012                  RemarkCallBack &&RemarkCB) const {
2013    auto &ORE = OREGetter(F);
2014
2015    if (RemarkName.starts_with("OMP"))
2016      ORE.emit([&]() {
2017        return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2018               << " [" << RemarkName << "]";
2019      });
2020    else
2021      ORE.emit(
2022          [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2023  }
2024
2025  /// The underlying module.
2026  Module &M;
2027
2028  /// The SCC we are operating on.
2029  SmallVectorImpl<Function *> &SCC;
2030
2031  /// Callback to update the call graph, the first argument is a removed call,
2032  /// the second an optional replacement call.
2033  CallGraphUpdater &CGUpdater;
2034
2035  /// Callback to get an OptimizationRemarkEmitter from a Function *
2036  OptimizationRemarkGetter OREGetter;
2037
2038  /// OpenMP-specific information cache. Also Used for Attributor runs.
2039  OMPInformationCache &OMPInfoCache;
2040
2041  /// Attributor instance.
2042  Attributor &A;
2043
2044  /// Helper function to run Attributor on SCC.
2045  bool runAttributor(bool IsModulePass) {
2046    if (SCC.empty())
2047      return false;
2048
2049    registerAAs(IsModulePass);
2050
2051    ChangeStatus Changed = A.run();
2052
2053    LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2054                      << " functions, result: " << Changed << ".\n");
2055
2056    if (Changed == ChangeStatus::CHANGED)
2057      OMPInfoCache.invalidateAnalyses();
2058
2059    return Changed == ChangeStatus::CHANGED;
2060  }
2061
2062  void registerFoldRuntimeCall(RuntimeFunction RF);
2063
2064  /// Populate the Attributor with abstract attribute opportunities in the
2065  /// functions.
2066  void registerAAs(bool IsModulePass);
2067
2068public:
2069  /// Callback to register AAs for live functions, including internal functions
2070  /// marked live during the traversal.
2071  static void registerAAsForFunction(Attributor &A, const Function &F);
2072};
2073
2074Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2075  if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2076      !OMPInfoCache.CGSCC->contains(&F))
2077    return nullptr;
2078
2079  // Use a scope to keep the lifetime of the CachedKernel short.
2080  {
2081    std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2082    if (CachedKernel)
2083      return *CachedKernel;
2084
2085    // TODO: We should use an AA to create an (optimistic and callback
2086    //       call-aware) call graph. For now we stick to simple patterns that
2087    //       are less powerful, basically the worst fixpoint.
2088    if (isOpenMPKernel(F)) {
2089      CachedKernel = Kernel(&F);
2090      return *CachedKernel;
2091    }
2092
2093    CachedKernel = nullptr;
2094    if (!F.hasLocalLinkage()) {
2095
2096      // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2097      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2098        return ORA << "Potentially unknown OpenMP target region caller.";
2099      };
2100      emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2101
2102      return nullptr;
2103    }
2104  }
2105
2106  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2107    if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2108      // Allow use in equality comparisons.
2109      if (Cmp->isEquality())
2110        return getUniqueKernelFor(*Cmp);
2111      return nullptr;
2112    }
2113    if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2114      // Allow direct calls.
2115      if (CB->isCallee(&U))
2116        return getUniqueKernelFor(*CB);
2117
2118      OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2119          OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2120      // Allow the use in __kmpc_parallel_51 calls.
2121      if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2122        return getUniqueKernelFor(*CB);
2123      return nullptr;
2124    }
2125    // Disallow every other use.
2126    return nullptr;
2127  };
2128
2129  // TODO: In the future we want to track more than just a unique kernel.
2130  SmallPtrSet<Kernel, 2> PotentialKernels;
2131  OMPInformationCache::foreachUse(F, [&](const Use &U) {
2132    PotentialKernels.insert(GetUniqueKernelForUse(U));
2133  });
2134
2135  Kernel K = nullptr;
2136  if (PotentialKernels.size() == 1)
2137    K = *PotentialKernels.begin();
2138
2139  // Cache the result.
2140  UniqueKernelMap[&F] = K;
2141
2142  return K;
2143}
2144
2145bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2146  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2147      OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2148
2149  bool Changed = false;
2150  if (!KernelParallelRFI)
2151    return Changed;
2152
2153  // If we have disabled state machine changes, exit
2154  if (DisableOpenMPOptStateMachineRewrite)
2155    return Changed;
2156
2157  for (Function *F : SCC) {
2158
2159    // Check if the function is a use in a __kmpc_parallel_51 call at
2160    // all.
2161    bool UnknownUse = false;
2162    bool KernelParallelUse = false;
2163    unsigned NumDirectCalls = 0;
2164
2165    SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2166    OMPInformationCache::foreachUse(*F, [&](Use &U) {
2167      if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2168        if (CB->isCallee(&U)) {
2169          ++NumDirectCalls;
2170          return;
2171        }
2172
2173      if (isa<ICmpInst>(U.getUser())) {
2174        ToBeReplacedStateMachineUses.push_back(&U);
2175        return;
2176      }
2177
2178      // Find wrapper functions that represent parallel kernels.
2179      CallInst *CI =
2180          OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2181      const unsigned int WrapperFunctionArgNo = 6;
2182      if (!KernelParallelUse && CI &&
2183          CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2184        KernelParallelUse = true;
2185        ToBeReplacedStateMachineUses.push_back(&U);
2186        return;
2187      }
2188      UnknownUse = true;
2189    });
2190
2191    // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2192    // use.
2193    if (!KernelParallelUse)
2194      continue;
2195
2196    // If this ever hits, we should investigate.
2197    // TODO: Checking the number of uses is not a necessary restriction and
2198    // should be lifted.
2199    if (UnknownUse || NumDirectCalls != 1 ||
2200        ToBeReplacedStateMachineUses.size() > 2) {
2201      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2202        return ORA << "Parallel region is used in "
2203                   << (UnknownUse ? "unknown" : "unexpected")
2204                   << " ways. Will not attempt to rewrite the state machine.";
2205      };
2206      emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2207      continue;
2208    }
2209
2210    // Even if we have __kmpc_parallel_51 calls, we (for now) give
2211    // up if the function is not called from a unique kernel.
2212    Kernel K = getUniqueKernelFor(*F);
2213    if (!K) {
2214      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2215        return ORA << "Parallel region is not called from a unique kernel. "
2216                      "Will not attempt to rewrite the state machine.";
2217      };
2218      emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2219      continue;
2220    }
2221
2222    // We now know F is a parallel body function called only from the kernel K.
2223    // We also identified the state machine uses in which we replace the
2224    // function pointer by a new global symbol for identification purposes. This
2225    // ensures only direct calls to the function are left.
2226
2227    Module &M = *F->getParent();
2228    Type *Int8Ty = Type::getInt8Ty(M.getContext());
2229
2230    auto *ID = new GlobalVariable(
2231        M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2232        UndefValue::get(Int8Ty), F->getName() + ".ID");
2233
2234    for (Use *U : ToBeReplacedStateMachineUses)
2235      U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2236          ID, U->get()->getType()));
2237
2238    ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2239
2240    Changed = true;
2241  }
2242
2243  return Changed;
2244}
2245
2246/// Abstract Attribute for tracking ICV values.
2247struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2248  using Base = StateWrapper<BooleanState, AbstractAttribute>;
2249  AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2250
2251  /// Returns true if value is assumed to be tracked.
2252  bool isAssumedTracked() const { return getAssumed(); }
2253
2254  /// Returns true if value is known to be tracked.
2255  bool isKnownTracked() const { return getAssumed(); }
2256
2257  /// Create an abstract attribute biew for the position \p IRP.
2258  static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2259
2260  /// Return the value with which \p I can be replaced for specific \p ICV.
2261  virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2262                                                     const Instruction *I,
2263                                                     Attributor &A) const {
2264    return std::nullopt;
2265  }
2266
2267  /// Return an assumed unique ICV value if a single candidate is found. If
2268  /// there cannot be one, return a nullptr. If it is not clear yet, return
2269  /// std::nullopt.
2270  virtual std::optional<Value *>
2271  getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2272
2273  // Currently only nthreads is being tracked.
2274  // this array will only grow with time.
2275  InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2276
2277  /// See AbstractAttribute::getName()
2278  const std::string getName() const override { return "AAICVTracker"; }
2279
2280  /// See AbstractAttribute::getIdAddr()
2281  const char *getIdAddr() const override { return &ID; }
2282
2283  /// This function should return true if the type of the \p AA is AAICVTracker
2284  static bool classof(const AbstractAttribute *AA) {
2285    return (AA->getIdAddr() == &ID);
2286  }
2287
2288  static const char ID;
2289};
2290
2291struct AAICVTrackerFunction : public AAICVTracker {
2292  AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2293      : AAICVTracker(IRP, A) {}
2294
2295  // FIXME: come up with better string.
2296  const std::string getAsStr(Attributor *) const override {
2297    return "ICVTrackerFunction";
2298  }
2299
2300  // FIXME: come up with some stats.
2301  void trackStatistics() const override {}
2302
2303  /// We don't manifest anything for this AA.
2304  ChangeStatus manifest(Attributor &A) override {
2305    return ChangeStatus::UNCHANGED;
2306  }
2307
2308  // Map of ICV to their values at specific program point.
2309  EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2310                  InternalControlVar::ICV___last>
2311      ICVReplacementValuesMap;
2312
2313  ChangeStatus updateImpl(Attributor &A) override {
2314    ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2315
2316    Function *F = getAnchorScope();
2317
2318    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2319
2320    for (InternalControlVar ICV : TrackableICVs) {
2321      auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2322
2323      auto &ValuesMap = ICVReplacementValuesMap[ICV];
2324      auto TrackValues = [&](Use &U, Function &) {
2325        CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2326        if (!CI)
2327          return false;
2328
2329        // FIXME: handle setters with more that 1 arguments.
2330        /// Track new value.
2331        if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2332          HasChanged = ChangeStatus::CHANGED;
2333
2334        return false;
2335      };
2336
2337      auto CallCheck = [&](Instruction &I) {
2338        std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2339        if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2340          HasChanged = ChangeStatus::CHANGED;
2341
2342        return true;
2343      };
2344
2345      // Track all changes of an ICV.
2346      SetterRFI.foreachUse(TrackValues, F);
2347
2348      bool UsedAssumedInformation = false;
2349      A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2350                                UsedAssumedInformation,
2351                                /* CheckBBLivenessOnly */ true);
2352
2353      /// TODO: Figure out a way to avoid adding entry in
2354      /// ICVReplacementValuesMap
2355      Instruction *Entry = &F->getEntryBlock().front();
2356      if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2357        ValuesMap.insert(std::make_pair(Entry, nullptr));
2358    }
2359
2360    return HasChanged;
2361  }
2362
2363  /// Helper to check if \p I is a call and get the value for it if it is
2364  /// unique.
2365  std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2366                                         InternalControlVar &ICV) const {
2367
2368    const auto *CB = dyn_cast<CallBase>(&I);
2369    if (!CB || CB->hasFnAttr("no_openmp") ||
2370        CB->hasFnAttr("no_openmp_routines"))
2371      return std::nullopt;
2372
2373    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2374    auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2375    auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2376    Function *CalledFunction = CB->getCalledFunction();
2377
2378    // Indirect call, assume ICV changes.
2379    if (CalledFunction == nullptr)
2380      return nullptr;
2381    if (CalledFunction == GetterRFI.Declaration)
2382      return std::nullopt;
2383    if (CalledFunction == SetterRFI.Declaration) {
2384      if (ICVReplacementValuesMap[ICV].count(&I))
2385        return ICVReplacementValuesMap[ICV].lookup(&I);
2386
2387      return nullptr;
2388    }
2389
2390    // Since we don't know, assume it changes the ICV.
2391    if (CalledFunction->isDeclaration())
2392      return nullptr;
2393
2394    const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2395        *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2396
2397    if (ICVTrackingAA->isAssumedTracked()) {
2398      std::optional<Value *> URV =
2399          ICVTrackingAA->getUniqueReplacementValue(ICV);
2400      if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2401                                                 OMPInfoCache)))
2402        return URV;
2403    }
2404
2405    // If we don't know, assume it changes.
2406    return nullptr;
2407  }
2408
2409  // We don't check unique value for a function, so return std::nullopt.
2410  std::optional<Value *>
2411  getUniqueReplacementValue(InternalControlVar ICV) const override {
2412    return std::nullopt;
2413  }
2414
2415  /// Return the value with which \p I can be replaced for specific \p ICV.
2416  std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2417                                             const Instruction *I,
2418                                             Attributor &A) const override {
2419    const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2420    if (ValuesMap.count(I))
2421      return ValuesMap.lookup(I);
2422
2423    SmallVector<const Instruction *, 16> Worklist;
2424    SmallPtrSet<const Instruction *, 16> Visited;
2425    Worklist.push_back(I);
2426
2427    std::optional<Value *> ReplVal;
2428
2429    while (!Worklist.empty()) {
2430      const Instruction *CurrInst = Worklist.pop_back_val();
2431      if (!Visited.insert(CurrInst).second)
2432        continue;
2433
2434      const BasicBlock *CurrBB = CurrInst->getParent();
2435
2436      // Go up and look for all potential setters/calls that might change the
2437      // ICV.
2438      while ((CurrInst = CurrInst->getPrevNode())) {
2439        if (ValuesMap.count(CurrInst)) {
2440          std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2441          // Unknown value, track new.
2442          if (!ReplVal) {
2443            ReplVal = NewReplVal;
2444            break;
2445          }
2446
2447          // If we found a new value, we can't know the icv value anymore.
2448          if (NewReplVal)
2449            if (ReplVal != NewReplVal)
2450              return nullptr;
2451
2452          break;
2453        }
2454
2455        std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2456        if (!NewReplVal)
2457          continue;
2458
2459        // Unknown value, track new.
2460        if (!ReplVal) {
2461          ReplVal = NewReplVal;
2462          break;
2463        }
2464
2465        // if (NewReplVal.hasValue())
2466        // We found a new value, we can't know the icv value anymore.
2467        if (ReplVal != NewReplVal)
2468          return nullptr;
2469      }
2470
2471      // If we are in the same BB and we have a value, we are done.
2472      if (CurrBB == I->getParent() && ReplVal)
2473        return ReplVal;
2474
2475      // Go through all predecessors and add terminators for analysis.
2476      for (const BasicBlock *Pred : predecessors(CurrBB))
2477        if (const Instruction *Terminator = Pred->getTerminator())
2478          Worklist.push_back(Terminator);
2479    }
2480
2481    return ReplVal;
2482  }
2483};
2484
2485struct AAICVTrackerFunctionReturned : AAICVTracker {
2486  AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2487      : AAICVTracker(IRP, A) {}
2488
2489  // FIXME: come up with better string.
2490  const std::string getAsStr(Attributor *) const override {
2491    return "ICVTrackerFunctionReturned";
2492  }
2493
2494  // FIXME: come up with some stats.
2495  void trackStatistics() const override {}
2496
2497  /// We don't manifest anything for this AA.
2498  ChangeStatus manifest(Attributor &A) override {
2499    return ChangeStatus::UNCHANGED;
2500  }
2501
2502  // Map of ICV to their values at specific program point.
2503  EnumeratedArray<std::optional<Value *>, InternalControlVar,
2504                  InternalControlVar::ICV___last>
2505      ICVReplacementValuesMap;
2506
2507  /// Return the value with which \p I can be replaced for specific \p ICV.
2508  std::optional<Value *>
2509  getUniqueReplacementValue(InternalControlVar ICV) const override {
2510    return ICVReplacementValuesMap[ICV];
2511  }
2512
2513  ChangeStatus updateImpl(Attributor &A) override {
2514    ChangeStatus Changed = ChangeStatus::UNCHANGED;
2515    const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2516        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2517
2518    if (!ICVTrackingAA->isAssumedTracked())
2519      return indicatePessimisticFixpoint();
2520
2521    for (InternalControlVar ICV : TrackableICVs) {
2522      std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2523      std::optional<Value *> UniqueICVValue;
2524
2525      auto CheckReturnInst = [&](Instruction &I) {
2526        std::optional<Value *> NewReplVal =
2527            ICVTrackingAA->getReplacementValue(ICV, &I, A);
2528
2529        // If we found a second ICV value there is no unique returned value.
2530        if (UniqueICVValue && UniqueICVValue != NewReplVal)
2531          return false;
2532
2533        UniqueICVValue = NewReplVal;
2534
2535        return true;
2536      };
2537
2538      bool UsedAssumedInformation = false;
2539      if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2540                                     UsedAssumedInformation,
2541                                     /* CheckBBLivenessOnly */ true))
2542        UniqueICVValue = nullptr;
2543
2544      if (UniqueICVValue == ReplVal)
2545        continue;
2546
2547      ReplVal = UniqueICVValue;
2548      Changed = ChangeStatus::CHANGED;
2549    }
2550
2551    return Changed;
2552  }
2553};
2554
2555struct AAICVTrackerCallSite : AAICVTracker {
2556  AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2557      : AAICVTracker(IRP, A) {}
2558
2559  void initialize(Attributor &A) override {
2560    assert(getAnchorScope() && "Expected anchor function");
2561
2562    // We only initialize this AA for getters, so we need to know which ICV it
2563    // gets.
2564    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2565    for (InternalControlVar ICV : TrackableICVs) {
2566      auto ICVInfo = OMPInfoCache.ICVs[ICV];
2567      auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2568      if (Getter.Declaration == getAssociatedFunction()) {
2569        AssociatedICV = ICVInfo.Kind;
2570        return;
2571      }
2572    }
2573
2574    /// Unknown ICV.
2575    indicatePessimisticFixpoint();
2576  }
2577
2578  ChangeStatus manifest(Attributor &A) override {
2579    if (!ReplVal || !*ReplVal)
2580      return ChangeStatus::UNCHANGED;
2581
2582    A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2583    A.deleteAfterManifest(*getCtxI());
2584
2585    return ChangeStatus::CHANGED;
2586  }
2587
2588  // FIXME: come up with better string.
2589  const std::string getAsStr(Attributor *) const override {
2590    return "ICVTrackerCallSite";
2591  }
2592
2593  // FIXME: come up with some stats.
2594  void trackStatistics() const override {}
2595
2596  InternalControlVar AssociatedICV;
2597  std::optional<Value *> ReplVal;
2598
2599  ChangeStatus updateImpl(Attributor &A) override {
2600    const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2601        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2602
2603    // We don't have any information, so we assume it changes the ICV.
2604    if (!ICVTrackingAA->isAssumedTracked())
2605      return indicatePessimisticFixpoint();
2606
2607    std::optional<Value *> NewReplVal =
2608        ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2609
2610    if (ReplVal == NewReplVal)
2611      return ChangeStatus::UNCHANGED;
2612
2613    ReplVal = NewReplVal;
2614    return ChangeStatus::CHANGED;
2615  }
2616
2617  // Return the value with which associated value can be replaced for specific
2618  // \p ICV.
2619  std::optional<Value *>
2620  getUniqueReplacementValue(InternalControlVar ICV) const override {
2621    return ReplVal;
2622  }
2623};
2624
2625struct AAICVTrackerCallSiteReturned : AAICVTracker {
2626  AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2627      : AAICVTracker(IRP, A) {}
2628
2629  // FIXME: come up with better string.
2630  const std::string getAsStr(Attributor *) const override {
2631    return "ICVTrackerCallSiteReturned";
2632  }
2633
2634  // FIXME: come up with some stats.
2635  void trackStatistics() const override {}
2636
2637  /// We don't manifest anything for this AA.
2638  ChangeStatus manifest(Attributor &A) override {
2639    return ChangeStatus::UNCHANGED;
2640  }
2641
2642  // Map of ICV to their values at specific program point.
2643  EnumeratedArray<std::optional<Value *>, InternalControlVar,
2644                  InternalControlVar::ICV___last>
2645      ICVReplacementValuesMap;
2646
2647  /// Return the value with which associated value can be replaced for specific
2648  /// \p ICV.
2649  std::optional<Value *>
2650  getUniqueReplacementValue(InternalControlVar ICV) const override {
2651    return ICVReplacementValuesMap[ICV];
2652  }
2653
2654  ChangeStatus updateImpl(Attributor &A) override {
2655    ChangeStatus Changed = ChangeStatus::UNCHANGED;
2656    const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2657        *this, IRPosition::returned(*getAssociatedFunction()),
2658        DepClassTy::REQUIRED);
2659
2660    // We don't have any information, so we assume it changes the ICV.
2661    if (!ICVTrackingAA->isAssumedTracked())
2662      return indicatePessimisticFixpoint();
2663
2664    for (InternalControlVar ICV : TrackableICVs) {
2665      std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2666      std::optional<Value *> NewReplVal =
2667          ICVTrackingAA->getUniqueReplacementValue(ICV);
2668
2669      if (ReplVal == NewReplVal)
2670        continue;
2671
2672      ReplVal = NewReplVal;
2673      Changed = ChangeStatus::CHANGED;
2674    }
2675    return Changed;
2676  }
2677};
2678
2679/// Determines if \p BB exits the function unconditionally itself or reaches a
2680/// block that does through only unique successors.
2681static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2682  if (succ_empty(BB))
2683    return true;
2684  const BasicBlock *const Successor = BB->getUniqueSuccessor();
2685  if (!Successor)
2686    return false;
2687  return hasFunctionEndAsUniqueSuccessor(Successor);
2688}
2689
2690struct AAExecutionDomainFunction : public AAExecutionDomain {
2691  AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2692      : AAExecutionDomain(IRP, A) {}
2693
2694  ~AAExecutionDomainFunction() { delete RPOT; }
2695
2696  void initialize(Attributor &A) override {
2697    Function *F = getAnchorScope();
2698    assert(F && "Expected anchor function");
2699    RPOT = new ReversePostOrderTraversal<Function *>(F);
2700  }
2701
2702  const std::string getAsStr(Attributor *) const override {
2703    unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2704    for (auto &It : BEDMap) {
2705      if (!It.getFirst())
2706        continue;
2707      TotalBlocks++;
2708      InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2709      AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2710                       It.getSecond().IsReachingAlignedBarrierOnly;
2711    }
2712    return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2713           std::to_string(AlignedBlocks) + " of " +
2714           std::to_string(TotalBlocks) +
2715           " executed by initial thread / aligned";
2716  }
2717
2718  /// See AbstractAttribute::trackStatistics().
2719  void trackStatistics() const override {}
2720
2721  ChangeStatus manifest(Attributor &A) override {
2722    LLVM_DEBUG({
2723      for (const BasicBlock &BB : *getAnchorScope()) {
2724        if (!isExecutedByInitialThreadOnly(BB))
2725          continue;
2726        dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2727               << BB.getName() << " is executed by a single thread.\n";
2728      }
2729    });
2730
2731    ChangeStatus Changed = ChangeStatus::UNCHANGED;
2732
2733    if (DisableOpenMPOptBarrierElimination)
2734      return Changed;
2735
2736    SmallPtrSet<CallBase *, 16> DeletedBarriers;
2737    auto HandleAlignedBarrier = [&](CallBase *CB) {
2738      const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2739      if (!ED.IsReachedFromAlignedBarrierOnly ||
2740          ED.EncounteredNonLocalSideEffect)
2741        return;
2742      if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2743        return;
2744
2745      // We can remove this barrier, if it is one, or aligned barriers reaching
2746      // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2747      // end should only be removed if the kernel end is their unique successor;
2748      // otherwise, they may have side-effects that aren't accounted for in the
2749      // kernel end in their other successors. If those barriers have other
2750      // barriers reaching them, those can be transitively removed as well as
2751      // long as the kernel end is also their unique successor.
2752      if (CB) {
2753        DeletedBarriers.insert(CB);
2754        A.deleteAfterManifest(*CB);
2755        ++NumBarriersEliminated;
2756        Changed = ChangeStatus::CHANGED;
2757      } else if (!ED.AlignedBarriers.empty()) {
2758        Changed = ChangeStatus::CHANGED;
2759        SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2760                                         ED.AlignedBarriers.end());
2761        SmallSetVector<CallBase *, 16> Visited;
2762        while (!Worklist.empty()) {
2763          CallBase *LastCB = Worklist.pop_back_val();
2764          if (!Visited.insert(LastCB))
2765            continue;
2766          if (LastCB->getFunction() != getAnchorScope())
2767            continue;
2768          if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2769            continue;
2770          if (!DeletedBarriers.count(LastCB)) {
2771            ++NumBarriersEliminated;
2772            A.deleteAfterManifest(*LastCB);
2773            continue;
2774          }
2775          // The final aligned barrier (LastCB) reaching the kernel end was
2776          // removed already. This means we can go one step further and remove
2777          // the barriers encoutered last before (LastCB).
2778          const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2779          Worklist.append(LastED.AlignedBarriers.begin(),
2780                          LastED.AlignedBarriers.end());
2781        }
2782      }
2783
2784      // If we actually eliminated a barrier we need to eliminate the associated
2785      // llvm.assumes as well to avoid creating UB.
2786      if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2787        for (auto *AssumeCB : ED.EncounteredAssumes)
2788          A.deleteAfterManifest(*AssumeCB);
2789    };
2790
2791    for (auto *CB : AlignedBarriers)
2792      HandleAlignedBarrier(CB);
2793
2794    // Handle the "kernel end barrier" for kernels too.
2795    if (omp::isOpenMPKernel(*getAnchorScope()))
2796      HandleAlignedBarrier(nullptr);
2797
2798    return Changed;
2799  }
2800
2801  bool isNoOpFence(const FenceInst &FI) const override {
2802    return getState().isValidState() && !NonNoOpFences.count(&FI);
2803  }
2804
2805  /// Merge barrier and assumption information from \p PredED into the successor
2806  /// \p ED.
2807  void
2808  mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2809                                           const ExecutionDomainTy &PredED);
2810
2811  /// Merge all information from \p PredED into the successor \p ED. If
2812  /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2813  /// represented by \p ED from this predecessor.
2814  bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2815                          const ExecutionDomainTy &PredED,
2816                          bool InitialEdgeOnly = false);
2817
2818  /// Accumulate information for the entry block in \p EntryBBED.
2819  bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2820
2821  /// See AbstractAttribute::updateImpl.
2822  ChangeStatus updateImpl(Attributor &A) override;
2823
2824  /// Query interface, see AAExecutionDomain
2825  ///{
2826  bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2827    if (!isValidState())
2828      return false;
2829    assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2830    return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2831  }
2832
2833  bool isExecutedInAlignedRegion(Attributor &A,
2834                                 const Instruction &I) const override {
2835    assert(I.getFunction() == getAnchorScope() &&
2836           "Instruction is out of scope!");
2837    if (!isValidState())
2838      return false;
2839
2840    bool ForwardIsOk = true;
2841    const Instruction *CurI;
2842
2843    // Check forward until a call or the block end is reached.
2844    CurI = &I;
2845    do {
2846      auto *CB = dyn_cast<CallBase>(CurI);
2847      if (!CB)
2848        continue;
2849      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2850        return true;
2851      const auto &It = CEDMap.find({CB, PRE});
2852      if (It == CEDMap.end())
2853        continue;
2854      if (!It->getSecond().IsReachingAlignedBarrierOnly)
2855        ForwardIsOk = false;
2856      break;
2857    } while ((CurI = CurI->getNextNonDebugInstruction()));
2858
2859    if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
2860      ForwardIsOk = false;
2861
2862    // Check backward until a call or the block beginning is reached.
2863    CurI = &I;
2864    do {
2865      auto *CB = dyn_cast<CallBase>(CurI);
2866      if (!CB)
2867        continue;
2868      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2869        return true;
2870      const auto &It = CEDMap.find({CB, POST});
2871      if (It == CEDMap.end())
2872        continue;
2873      if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2874        break;
2875      return false;
2876    } while ((CurI = CurI->getPrevNonDebugInstruction()));
2877
2878    // Delayed decision on the forward pass to allow aligned barrier detection
2879    // in the backwards traversal.
2880    if (!ForwardIsOk)
2881      return false;
2882
2883    if (!CurI) {
2884      const BasicBlock *BB = I.getParent();
2885      if (BB == &BB->getParent()->getEntryBlock())
2886        return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
2887      if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
2888            return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2889          })) {
2890        return false;
2891      }
2892    }
2893
2894    // On neither traversal we found a anything but aligned barriers.
2895    return true;
2896  }
2897
2898  ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2899    assert(isValidState() &&
2900           "No request should be made against an invalid state!");
2901    return BEDMap.lookup(&BB);
2902  }
2903  std::pair<ExecutionDomainTy, ExecutionDomainTy>
2904  getExecutionDomain(const CallBase &CB) const override {
2905    assert(isValidState() &&
2906           "No request should be made against an invalid state!");
2907    return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2908  }
2909  ExecutionDomainTy getFunctionExecutionDomain() const override {
2910    assert(isValidState() &&
2911           "No request should be made against an invalid state!");
2912    return InterProceduralED;
2913  }
2914  ///}
2915
2916  // Check if the edge into the successor block contains a condition that only
2917  // lets the main thread execute it.
2918  static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
2919                                      BasicBlock &SuccessorBB) {
2920    if (!Edge || !Edge->isConditional())
2921      return false;
2922    if (Edge->getSuccessor(0) != &SuccessorBB)
2923      return false;
2924
2925    auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2926    if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2927      return false;
2928
2929    ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2930    if (!C)
2931      return false;
2932
2933    // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2934    if (C->isAllOnesValue()) {
2935      auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2936      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2937      auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2938      CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2939      if (!CB)
2940        return false;
2941      ConstantStruct *KernelEnvC =
2942          KernelInfo::getKernelEnvironementFromKernelInitCB(CB);
2943      ConstantInt *ExecModeC =
2944          KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2945      return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
2946    }
2947
2948    if (C->isZero()) {
2949      // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2950      if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2951        if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2952          return true;
2953
2954      // Match: 0 == llvm.amdgcn.workitem.id.x()
2955      if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2956        if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2957          return true;
2958    }
2959
2960    return false;
2961  };
2962
2963  /// Mapping containing information about the function for other AAs.
2964  ExecutionDomainTy InterProceduralED;
2965
2966  enum Direction { PRE = 0, POST = 1 };
2967  /// Mapping containing information per block.
2968  DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
2969  DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
2970      CEDMap;
2971  SmallSetVector<CallBase *, 16> AlignedBarriers;
2972
2973  ReversePostOrderTraversal<Function *> *RPOT = nullptr;
2974
2975  /// Set \p R to \V and report true if that changed \p R.
2976  static bool setAndRecord(bool &R, bool V) {
2977    bool Eq = (R == V);
2978    R = V;
2979    return !Eq;
2980  }
2981
2982  /// Collection of fences known to be non-no-opt. All fences not in this set
2983  /// can be assumed no-opt.
2984  SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
2985};
2986
2987void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2988    Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2989  for (auto *EA : PredED.EncounteredAssumes)
2990    ED.addAssumeInst(A, *EA);
2991
2992  for (auto *AB : PredED.AlignedBarriers)
2993    ED.addAlignedBarrier(A, *AB);
2994}
2995
2996bool AAExecutionDomainFunction::mergeInPredecessor(
2997    Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
2998    bool InitialEdgeOnly) {
2999
3000  bool Changed = false;
3001  Changed |=
3002      setAndRecord(ED.IsExecutedByInitialThreadOnly,
3003                   InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3004                                       ED.IsExecutedByInitialThreadOnly));
3005
3006  Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3007                          ED.IsReachedFromAlignedBarrierOnly &&
3008                              PredED.IsReachedFromAlignedBarrierOnly);
3009  Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3010                          ED.EncounteredNonLocalSideEffect |
3011                              PredED.EncounteredNonLocalSideEffect);
3012  // Do not track assumptions and barriers as part of Changed.
3013  if (ED.IsReachedFromAlignedBarrierOnly)
3014    mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3015  else
3016    ED.clearAssumeInstAndAlignedBarriers();
3017  return Changed;
3018}
3019
3020bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3021                                              ExecutionDomainTy &EntryBBED) {
3022  SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;
3023  auto PredForCallSite = [&](AbstractCallSite ACS) {
3024    const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3025        *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3026        DepClassTy::OPTIONAL);
3027    if (!EDAA || !EDAA->getState().isValidState())
3028      return false;
3029    CallSiteEDs.emplace_back(
3030        EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3031    return true;
3032  };
3033
3034  ExecutionDomainTy ExitED;
3035  bool AllCallSitesKnown;
3036  if (A.checkForAllCallSites(PredForCallSite, *this,
3037                             /* RequiresAllCallSites */ true,
3038                             AllCallSitesKnown)) {
3039    for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3040      mergeInPredecessor(A, EntryBBED, CSInED);
3041      ExitED.IsReachingAlignedBarrierOnly &=
3042          CSOutED.IsReachingAlignedBarrierOnly;
3043    }
3044
3045  } else {
3046    // We could not find all predecessors, so this is either a kernel or a
3047    // function with external linkage (or with some other weird uses).
3048    if (omp::isOpenMPKernel(*getAnchorScope())) {
3049      EntryBBED.IsExecutedByInitialThreadOnly = false;
3050      EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3051      EntryBBED.EncounteredNonLocalSideEffect = false;
3052      ExitED.IsReachingAlignedBarrierOnly = false;
3053    } else {
3054      EntryBBED.IsExecutedByInitialThreadOnly = false;
3055      EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3056      EntryBBED.EncounteredNonLocalSideEffect = true;
3057      ExitED.IsReachingAlignedBarrierOnly = false;
3058    }
3059  }
3060
3061  bool Changed = false;
3062  auto &FnED = BEDMap[nullptr];
3063  Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3064                          FnED.IsReachedFromAlignedBarrierOnly &
3065                              EntryBBED.IsReachedFromAlignedBarrierOnly);
3066  Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3067                          FnED.IsReachingAlignedBarrierOnly &
3068                              ExitED.IsReachingAlignedBarrierOnly);
3069  Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3070                          EntryBBED.IsExecutedByInitialThreadOnly);
3071  return Changed;
3072}
3073
3074ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3075
3076  bool Changed = false;
3077
3078  // Helper to deal with an aligned barrier encountered during the forward
3079  // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3080  // it was encountered.
3081  auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3082    Changed |= AlignedBarriers.insert(&CB);
3083    // First, update the barrier ED kept in the separate CEDMap.
3084    auto &CallInED = CEDMap[{&CB, PRE}];
3085    Changed |= mergeInPredecessor(A, CallInED, ED);
3086    CallInED.IsReachingAlignedBarrierOnly = true;
3087    // Next adjust the ED we use for the traversal.
3088    ED.EncounteredNonLocalSideEffect = false;
3089    ED.IsReachedFromAlignedBarrierOnly = true;
3090    // Aligned barrier collection has to come last.
3091    ED.clearAssumeInstAndAlignedBarriers();
3092    ED.addAlignedBarrier(A, CB);
3093    auto &CallOutED = CEDMap[{&CB, POST}];
3094    Changed |= mergeInPredecessor(A, CallOutED, ED);
3095  };
3096
3097  auto *LivenessAA =
3098      A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
3099
3100  Function *F = getAnchorScope();
3101  BasicBlock &EntryBB = F->getEntryBlock();
3102  bool IsKernel = omp::isOpenMPKernel(*F);
3103
3104  SmallVector<Instruction *> SyncInstWorklist;
3105  for (auto &RIt : *RPOT) {
3106    BasicBlock &BB = *RIt;
3107
3108    bool IsEntryBB = &BB == &EntryBB;
3109    // TODO: We use local reasoning since we don't have a divergence analysis
3110    // 	     running as well. We could basically allow uniform branches here.
3111    bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3112    bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3113    ExecutionDomainTy ED;
3114    // Propagate "incoming edges" into information about this block.
3115    if (IsEntryBB) {
3116      Changed |= handleCallees(A, ED);
3117    } else {
3118      // For live non-entry blocks we only propagate
3119      // information via live edges.
3120      if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3121        continue;
3122
3123      for (auto *PredBB : predecessors(&BB)) {
3124        if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3125          continue;
3126        bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3127            A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
3128        mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
3129      }
3130    }
3131
3132    // Now we traverse the block, accumulate effects in ED and attach
3133    // information to calls.
3134    for (Instruction &I : BB) {
3135      bool UsedAssumedInformation;
3136      if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
3137                          /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
3138                          /* CheckForDeadStore */ true))
3139        continue;
3140
3141      // Asummes and "assume-like" (dbg, lifetime, ...) are handled first, the
3142      // former is collected the latter is ignored.
3143      if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
3144        if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3145          ED.addAssumeInst(A, *AI);
3146          continue;
3147        }
3148        // TODO: Should we also collect and delete lifetime markers?
3149        if (II->isAssumeLikeIntrinsic())
3150          continue;
3151      }
3152
3153      if (auto *FI = dyn_cast<FenceInst>(&I)) {
3154        if (!ED.EncounteredNonLocalSideEffect) {
3155          // An aligned fence without non-local side-effects is a no-op.
3156          if (ED.IsReachedFromAlignedBarrierOnly)
3157            continue;
3158          // A non-aligned fence without non-local side-effects is a no-op
3159          // if the ordering only publishes non-local side-effects (or less).
3160          switch (FI->getOrdering()) {
3161          case AtomicOrdering::NotAtomic:
3162            continue;
3163          case AtomicOrdering::Unordered:
3164            continue;
3165          case AtomicOrdering::Monotonic:
3166            continue;
3167          case AtomicOrdering::Acquire:
3168            break;
3169          case AtomicOrdering::Release:
3170            continue;
3171          case AtomicOrdering::AcquireRelease:
3172            break;
3173          case AtomicOrdering::SequentiallyConsistent:
3174            break;
3175          };
3176        }
3177        NonNoOpFences.insert(FI);
3178      }
3179
3180      auto *CB = dyn_cast<CallBase>(&I);
3181      bool IsNoSync = AA::isNoSyncInst(A, I, *this);
3182      bool IsAlignedBarrier =
3183          !IsNoSync && CB &&
3184          AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
3185
3186      AlignedBarrierLastInBlock &= IsNoSync;
3187      IsExplicitlyAligned &= IsNoSync;
3188
3189      // Next we check for calls. Aligned barriers are handled
3190      // explicitly, everything else is kept for the backward traversal and will
3191      // also affect our state.
3192      if (CB) {
3193        if (IsAlignedBarrier) {
3194          HandleAlignedBarrier(*CB, ED);
3195          AlignedBarrierLastInBlock = true;
3196          IsExplicitlyAligned = true;
3197          continue;
3198        }
3199
3200        // Check the pointer(s) of a memory intrinsic explicitly.
3201        if (isa<MemIntrinsic>(&I)) {
3202          if (!ED.EncounteredNonLocalSideEffect &&
3203              AA::isPotentiallyAffectedByBarrier(A, I, *this))
3204            ED.EncounteredNonLocalSideEffect = true;
3205          if (!IsNoSync) {
3206            ED.IsReachedFromAlignedBarrierOnly = false;
3207            SyncInstWorklist.push_back(&I);
3208          }
3209          continue;
3210        }
3211
3212        // Record how we entered the call, then accumulate the effect of the
3213        // call in ED for potential use by the callee.
3214        auto &CallInED = CEDMap[{CB, PRE}];
3215        Changed |= mergeInPredecessor(A, CallInED, ED);
3216
3217        // If we have a sync-definition we can check if it starts/ends in an
3218        // aligned barrier. If we are unsure we assume any sync breaks
3219        // alignment.
3220        Function *Callee = CB->getCalledFunction();
3221        if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3222          const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3223              *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3224          if (EDAA && EDAA->getState().isValidState()) {
3225            const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3226            ED.IsReachedFromAlignedBarrierOnly =
3227                CalleeED.IsReachedFromAlignedBarrierOnly;
3228            AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3229            if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3230              ED.EncounteredNonLocalSideEffect |=
3231                  CalleeED.EncounteredNonLocalSideEffect;
3232            else
3233              ED.EncounteredNonLocalSideEffect =
3234                  CalleeED.EncounteredNonLocalSideEffect;
3235            if (!CalleeED.IsReachingAlignedBarrierOnly) {
3236              Changed |=
3237                  setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3238              SyncInstWorklist.push_back(&I);
3239            }
3240            if (CalleeED.IsReachedFromAlignedBarrierOnly)
3241              mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3242            auto &CallOutED = CEDMap[{CB, POST}];
3243            Changed |= mergeInPredecessor(A, CallOutED, ED);
3244            continue;
3245          }
3246        }
3247        if (!IsNoSync) {
3248          ED.IsReachedFromAlignedBarrierOnly = false;
3249          Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3250          SyncInstWorklist.push_back(&I);
3251        }
3252        AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3253        ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3254        auto &CallOutED = CEDMap[{CB, POST}];
3255        Changed |= mergeInPredecessor(A, CallOutED, ED);
3256      }
3257
3258      if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3259        continue;
3260
3261      // If we have a callee we try to use fine-grained information to
3262      // determine local side-effects.
3263      if (CB) {
3264        const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3265            *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3266
3267        auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3268                              AAMemoryLocation::AccessKind,
3269                              AAMemoryLocation::MemoryLocationsKind) {
3270          return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3271        };
3272        if (MemAA && MemAA->getState().isValidState() &&
3273            MemAA->checkForAllAccessesToMemoryKind(
3274                AccessPred, AAMemoryLocation::ALL_LOCATIONS))
3275          continue;
3276      }
3277
3278      auto &InfoCache = A.getInfoCache();
3279      if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3280        continue;
3281
3282      if (auto *LI = dyn_cast<LoadInst>(&I))
3283        if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3284          continue;
3285
3286      if (!ED.EncounteredNonLocalSideEffect &&
3287          AA::isPotentiallyAffectedByBarrier(A, I, *this))
3288        ED.EncounteredNonLocalSideEffect = true;
3289    }
3290
3291    bool IsEndAndNotReachingAlignedBarriersOnly = false;
3292    if (!isa<UnreachableInst>(BB.getTerminator()) &&
3293        !BB.getTerminator()->getNumSuccessors()) {
3294
3295      Changed |= mergeInPredecessor(A, InterProceduralED, ED);
3296
3297      auto &FnED = BEDMap[nullptr];
3298      if (IsKernel && !IsExplicitlyAligned)
3299        FnED.IsReachingAlignedBarrierOnly = false;
3300      Changed |= mergeInPredecessor(A, FnED, ED);
3301
3302      if (!FnED.IsReachingAlignedBarrierOnly) {
3303        IsEndAndNotReachingAlignedBarriersOnly = true;
3304        SyncInstWorklist.push_back(BB.getTerminator());
3305        auto &BBED = BEDMap[&BB];
3306        Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
3307      }
3308    }
3309
3310    ExecutionDomainTy &StoredED = BEDMap[&BB];
3311    ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3312                                      !IsEndAndNotReachingAlignedBarriersOnly;
3313
3314    // Check if we computed anything different as part of the forward
3315    // traversal. We do not take assumptions and aligned barriers into account
3316    // as they do not influence the state we iterate. Backward traversal values
3317    // are handled later on.
3318    if (ED.IsExecutedByInitialThreadOnly !=
3319            StoredED.IsExecutedByInitialThreadOnly ||
3320        ED.IsReachedFromAlignedBarrierOnly !=
3321            StoredED.IsReachedFromAlignedBarrierOnly ||
3322        ED.EncounteredNonLocalSideEffect !=
3323            StoredED.EncounteredNonLocalSideEffect)
3324      Changed = true;
3325
3326    // Update the state with the new value.
3327    StoredED = std::move(ED);
3328  }
3329
3330  // Propagate (non-aligned) sync instruction effects backwards until the
3331  // entry is hit or an aligned barrier.
3332  SmallSetVector<BasicBlock *, 16> Visited;
3333  while (!SyncInstWorklist.empty()) {
3334    Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3335    Instruction *CurInst = SyncInst;
3336    bool HitAlignedBarrierOrKnownEnd = false;
3337    while ((CurInst = CurInst->getPrevNode())) {
3338      auto *CB = dyn_cast<CallBase>(CurInst);
3339      if (!CB)
3340        continue;
3341      auto &CallOutED = CEDMap[{CB, POST}];
3342      Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
3343      auto &CallInED = CEDMap[{CB, PRE}];
3344      HitAlignedBarrierOrKnownEnd =
3345          AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3346      if (HitAlignedBarrierOrKnownEnd)
3347        break;
3348      Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3349    }
3350    if (HitAlignedBarrierOrKnownEnd)
3351      continue;
3352    BasicBlock *SyncBB = SyncInst->getParent();
3353    for (auto *PredBB : predecessors(SyncBB)) {
3354      if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3355        continue;
3356      if (!Visited.insert(PredBB))
3357        continue;
3358      auto &PredED = BEDMap[PredBB];
3359      if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
3360        Changed = true;
3361        SyncInstWorklist.push_back(PredBB->getTerminator());
3362      }
3363    }
3364    if (SyncBB != &EntryBB)
3365      continue;
3366    Changed |=
3367        setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
3368  }
3369
3370  return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3371}
3372
3373/// Try to replace memory allocation calls called by a single thread with a
3374/// static buffer of shared memory.
3375struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3376  using Base = StateWrapper<BooleanState, AbstractAttribute>;
3377  AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3378
3379  /// Create an abstract attribute view for the position \p IRP.
3380  static AAHeapToShared &createForPosition(const IRPosition &IRP,
3381                                           Attributor &A);
3382
3383  /// Returns true if HeapToShared conversion is assumed to be possible.
3384  virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3385
3386  /// Returns true if HeapToShared conversion is assumed and the CB is a
3387  /// callsite to a free operation to be removed.
3388  virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3389
3390  /// See AbstractAttribute::getName().
3391  const std::string getName() const override { return "AAHeapToShared"; }
3392
3393  /// See AbstractAttribute::getIdAddr().
3394  const char *getIdAddr() const override { return &ID; }
3395
3396  /// This function should return true if the type of the \p AA is
3397  /// AAHeapToShared.
3398  static bool classof(const AbstractAttribute *AA) {
3399    return (AA->getIdAddr() == &ID);
3400  }
3401
3402  /// Unique ID (due to the unique address)
3403  static const char ID;
3404};
3405
3406struct AAHeapToSharedFunction : public AAHeapToShared {
3407  AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3408      : AAHeapToShared(IRP, A) {}
3409
3410  const std::string getAsStr(Attributor *) const override {
3411    return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3412           " malloc calls eligible.";
3413  }
3414
3415  /// See AbstractAttribute::trackStatistics().
3416  void trackStatistics() const override {}
3417
3418  /// This functions finds free calls that will be removed by the
3419  /// HeapToShared transformation.
3420  void findPotentialRemovedFreeCalls(Attributor &A) {
3421    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3422    auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3423
3424    PotentialRemovedFreeCalls.clear();
3425    // Update free call users of found malloc calls.
3426    for (CallBase *CB : MallocCalls) {
3427      SmallVector<CallBase *, 4> FreeCalls;
3428      for (auto *U : CB->users()) {
3429        CallBase *C = dyn_cast<CallBase>(U);
3430        if (C && C->getCalledFunction() == FreeRFI.Declaration)
3431          FreeCalls.push_back(C);
3432      }
3433
3434      if (FreeCalls.size() != 1)
3435        continue;
3436
3437      PotentialRemovedFreeCalls.insert(FreeCalls.front());
3438    }
3439  }
3440
3441  void initialize(Attributor &A) override {
3442    if (DisableOpenMPOptDeglobalization) {
3443      indicatePessimisticFixpoint();
3444      return;
3445    }
3446
3447    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3448    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3449    if (!RFI.Declaration)
3450      return;
3451
3452    Attributor::SimplifictionCallbackTy SCB =
3453        [](const IRPosition &, const AbstractAttribute *,
3454           bool &) -> std::optional<Value *> { return nullptr; };
3455
3456    Function *F = getAnchorScope();
3457    for (User *U : RFI.Declaration->users())
3458      if (CallBase *CB = dyn_cast<CallBase>(U)) {
3459        if (CB->getFunction() != F)
3460          continue;
3461        MallocCalls.insert(CB);
3462        A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3463                                         SCB);
3464      }
3465
3466    findPotentialRemovedFreeCalls(A);
3467  }
3468
3469  bool isAssumedHeapToShared(CallBase &CB) const override {
3470    return isValidState() && MallocCalls.count(&CB);
3471  }
3472
3473  bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3474    return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3475  }
3476
3477  ChangeStatus manifest(Attributor &A) override {
3478    if (MallocCalls.empty())
3479      return ChangeStatus::UNCHANGED;
3480
3481    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3482    auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3483
3484    Function *F = getAnchorScope();
3485    auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3486                                            DepClassTy::OPTIONAL);
3487
3488    ChangeStatus Changed = ChangeStatus::UNCHANGED;
3489    for (CallBase *CB : MallocCalls) {
3490      // Skip replacing this if HeapToStack has already claimed it.
3491      if (HS && HS->isAssumedHeapToStack(*CB))
3492        continue;
3493
3494      // Find the unique free call to remove it.
3495      SmallVector<CallBase *, 4> FreeCalls;
3496      for (auto *U : CB->users()) {
3497        CallBase *C = dyn_cast<CallBase>(U);
3498        if (C && C->getCalledFunction() == FreeCall.Declaration)
3499          FreeCalls.push_back(C);
3500      }
3501      if (FreeCalls.size() != 1)
3502        continue;
3503
3504      auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3505
3506      if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3507        LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3508                          << " with shared memory."
3509                          << " Shared memory usage is limited to "
3510                          << SharedMemoryLimit << " bytes\n");
3511        continue;
3512      }
3513
3514      LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3515                        << " with " << AllocSize->getZExtValue()
3516                        << " bytes of shared memory\n");
3517
3518      // Create a new shared memory buffer of the same size as the allocation
3519      // and replace all the uses of the original allocation with it.
3520      Module *M = CB->getModule();
3521      Type *Int8Ty = Type::getInt8Ty(M->getContext());
3522      Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3523      auto *SharedMem = new GlobalVariable(
3524          *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3525          PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3526          GlobalValue::NotThreadLocal,
3527          static_cast<unsigned>(AddressSpace::Shared));
3528      auto *NewBuffer =
3529          ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
3530
3531      auto Remark = [&](OptimizationRemark OR) {
3532        return OR << "Replaced globalized variable with "
3533                  << ore::NV("SharedMemory", AllocSize->getZExtValue())
3534                  << (AllocSize->isOne() ? " byte " : " bytes ")
3535                  << "of shared memory.";
3536      };
3537      A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3538
3539      MaybeAlign Alignment = CB->getRetAlign();
3540      assert(Alignment &&
3541             "HeapToShared on allocation without alignment attribute");
3542      SharedMem->setAlignment(*Alignment);
3543
3544      A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3545      A.deleteAfterManifest(*CB);
3546      A.deleteAfterManifest(*FreeCalls.front());
3547
3548      SharedMemoryUsed += AllocSize->getZExtValue();
3549      NumBytesMovedToSharedMemory = SharedMemoryUsed;
3550      Changed = ChangeStatus::CHANGED;
3551    }
3552
3553    return Changed;
3554  }
3555
3556  ChangeStatus updateImpl(Attributor &A) override {
3557    if (MallocCalls.empty())
3558      return indicatePessimisticFixpoint();
3559    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3560    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3561    if (!RFI.Declaration)
3562      return ChangeStatus::UNCHANGED;
3563
3564    Function *F = getAnchorScope();
3565
3566    auto NumMallocCalls = MallocCalls.size();
3567
3568    // Only consider malloc calls executed by a single thread with a constant.
3569    for (User *U : RFI.Declaration->users()) {
3570      if (CallBase *CB = dyn_cast<CallBase>(U)) {
3571        if (CB->getCaller() != F)
3572          continue;
3573        if (!MallocCalls.count(CB))
3574          continue;
3575        if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3576          MallocCalls.remove(CB);
3577          continue;
3578        }
3579        const auto *ED = A.getAAFor<AAExecutionDomain>(
3580            *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3581        if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3582          MallocCalls.remove(CB);
3583      }
3584    }
3585
3586    findPotentialRemovedFreeCalls(A);
3587
3588    if (NumMallocCalls != MallocCalls.size())
3589      return ChangeStatus::CHANGED;
3590
3591    return ChangeStatus::UNCHANGED;
3592  }
3593
3594  /// Collection of all malloc calls in a function.
3595  SmallSetVector<CallBase *, 4> MallocCalls;
3596  /// Collection of potentially removed free calls in a function.
3597  SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3598  /// The total amount of shared memory that has been used for HeapToShared.
3599  unsigned SharedMemoryUsed = 0;
3600};
3601
3602struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3603  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3604  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3605
3606  /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3607  /// unknown callees.
3608  static bool requiresCalleeForCallBase() { return false; }
3609
3610  /// Statistics are tracked as part of manifest for now.
3611  void trackStatistics() const override {}
3612
3613  /// See AbstractAttribute::getAsStr()
3614  const std::string getAsStr(Attributor *) const override {
3615    if (!isValidState())
3616      return "<invalid>";
3617    return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3618                                                            : "generic") +
3619           std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3620                                                               : "") +
3621           std::string(" #PRs: ") +
3622           (ReachedKnownParallelRegions.isValidState()
3623                ? std::to_string(ReachedKnownParallelRegions.size())
3624                : "<invalid>") +
3625           ", #Unknown PRs: " +
3626           (ReachedUnknownParallelRegions.isValidState()
3627                ? std::to_string(ReachedUnknownParallelRegions.size())
3628                : "<invalid>") +
3629           ", #Reaching Kernels: " +
3630           (ReachingKernelEntries.isValidState()
3631                ? std::to_string(ReachingKernelEntries.size())
3632                : "<invalid>") +
3633           ", #ParLevels: " +
3634           (ParallelLevels.isValidState()
3635                ? std::to_string(ParallelLevels.size())
3636                : "<invalid>") +
3637           ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3638  }
3639
3640  /// Create an abstract attribute biew for the position \p IRP.
3641  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3642
3643  /// See AbstractAttribute::getName()
3644  const std::string getName() const override { return "AAKernelInfo"; }
3645
3646  /// See AbstractAttribute::getIdAddr()
3647  const char *getIdAddr() const override { return &ID; }
3648
3649  /// This function should return true if the type of the \p AA is AAKernelInfo
3650  static bool classof(const AbstractAttribute *AA) {
3651    return (AA->getIdAddr() == &ID);
3652  }
3653
3654  static const char ID;
3655};
3656
3657/// The function kernel info abstract attribute, basically, what can we say
3658/// about a function with regards to the KernelInfoState.
3659struct AAKernelInfoFunction : AAKernelInfo {
3660  AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3661      : AAKernelInfo(IRP, A) {}
3662
3663  SmallPtrSet<Instruction *, 4> GuardedInstructions;
3664
3665  SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3666    return GuardedInstructions;
3667  }
3668
3669  void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3670    Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3671        KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3672    assert(NewKernelEnvC && "Failed to create new kernel environment");
3673    KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3674  }
3675
3676#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)                        \
3677  void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) {                 \
3678    ConstantStruct *ConfigC =                                                  \
3679        KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC);         \
3680    Constant *NewConfigC = ConstantFoldInsertValueInstruction(                 \
3681        ConfigC, NewVal, {KernelInfo::MEMBER##Idx});                           \
3682    assert(NewConfigC && "Failed to create new configuration environment");    \
3683    setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC));     \
3684  }
3685
3686  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3687  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3688  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3689  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3690  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3691  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3692  KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3693
3694#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
3695
3696  /// See AbstractAttribute::initialize(...).
3697  void initialize(Attributor &A) override {
3698    // This is a high-level transform that might change the constant arguments
3699    // of the init and dinit calls. We need to tell the Attributor about this
3700    // to avoid other parts using the current constant value for simpliication.
3701    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3702
3703    Function *Fn = getAnchorScope();
3704
3705    OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3706        OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3707    OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3708        OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3709
3710    // For kernels we perform more initialization work, first we find the init
3711    // and deinit calls.
3712    auto StoreCallBase = [](Use &U,
3713                            OMPInformationCache::RuntimeFunctionInfo &RFI,
3714                            CallBase *&Storage) {
3715      CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3716      assert(CB &&
3717             "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3718      assert(!Storage &&
3719             "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3720      Storage = CB;
3721      return false;
3722    };
3723    InitRFI.foreachUse(
3724        [&](Use &U, Function &) {
3725          StoreCallBase(U, InitRFI, KernelInitCB);
3726          return false;
3727        },
3728        Fn);
3729    DeinitRFI.foreachUse(
3730        [&](Use &U, Function &) {
3731          StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3732          return false;
3733        },
3734        Fn);
3735
3736    // Ignore kernels without initializers such as global constructors.
3737    if (!KernelInitCB || !KernelDeinitCB)
3738      return;
3739
3740    // Add itself to the reaching kernel and set IsKernelEntry.
3741    ReachingKernelEntries.insert(Fn);
3742    IsKernelEntry = true;
3743
3744    KernelEnvC =
3745        KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3746    GlobalVariable *KernelEnvGV =
3747        KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3748
3749    Attributor::GlobalVariableSimplifictionCallbackTy
3750        KernelConfigurationSimplifyCB =
3751            [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3752                bool &UsedAssumedInformation) -> std::optional<Constant *> {
3753      if (!isAtFixpoint()) {
3754        if (!AA)
3755          return nullptr;
3756        UsedAssumedInformation = true;
3757        A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3758      }
3759      return KernelEnvC;
3760    };
3761
3762    A.registerGlobalVariableSimplificationCallback(
3763        *KernelEnvGV, KernelConfigurationSimplifyCB);
3764
3765    // Check if we know we are in SPMD-mode already.
3766    ConstantInt *ExecModeC =
3767        KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3768    ConstantInt *AssumedExecModeC = ConstantInt::get(
3769        ExecModeC->getIntegerType(),
3770        ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3771    if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3772      SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3773    else if (DisableOpenMPOptSPMDization)
3774      // This is a generic region but SPMDization is disabled so stop
3775      // tracking.
3776      SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3777    else
3778      setExecModeOfKernelEnvironment(AssumedExecModeC);
3779
3780    const Triple T(Fn->getParent()->getTargetTriple());
3781    auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3782    auto [MinThreads, MaxThreads] =
3783        OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);
3784    if (MinThreads)
3785      setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3786    if (MaxThreads)
3787      setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3788    auto [MinTeams, MaxTeams] =
3789        OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);
3790    if (MinTeams)
3791      setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3792    if (MaxTeams)
3793      setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3794
3795    ConstantInt *MayUseNestedParallelismC =
3796        KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3797    ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3798        MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3799    setMayUseNestedParallelismOfKernelEnvironment(
3800        AssumedMayUseNestedParallelismC);
3801
3802    if (!DisableOpenMPOptStateMachineRewrite) {
3803      ConstantInt *UseGenericStateMachineC =
3804          KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3805              KernelEnvC);
3806      ConstantInt *AssumedUseGenericStateMachineC =
3807          ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3808      setUseGenericStateMachineOfKernelEnvironment(
3809          AssumedUseGenericStateMachineC);
3810    }
3811
3812    // Register virtual uses of functions we might need to preserve.
3813    auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3814                                  Attributor::VirtualUseCallbackTy &CB) {
3815      if (!OMPInfoCache.RFIs[RFKind].Declaration)
3816        return;
3817      A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3818    };
3819
3820    // Add a dependence to ensure updates if the state changes.
3821    auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3822                            const AbstractAttribute *QueryingAA) {
3823      if (QueryingAA) {
3824        A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3825      }
3826      return true;
3827    };
3828
3829    Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3830        [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3831          // Whenever we create a custom state machine we will insert calls to
3832          // __kmpc_get_hardware_num_threads_in_block,
3833          // __kmpc_get_warp_size,
3834          // __kmpc_barrier_simple_generic,
3835          // __kmpc_kernel_parallel, and
3836          // __kmpc_kernel_end_parallel.
3837          // Not needed if we are on track for SPMDzation.
3838          if (SPMDCompatibilityTracker.isValidState())
3839            return AddDependence(A, this, QueryingAA);
3840          // Not needed if we can't rewrite due to an invalid state.
3841          if (!ReachedKnownParallelRegions.isValidState())
3842            return AddDependence(A, this, QueryingAA);
3843          return false;
3844        };
3845
3846    // Not needed if we are pre-runtime merge.
3847    if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3848      RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3849                         CustomStateMachineUseCB);
3850      RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3851      RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3852                         CustomStateMachineUseCB);
3853      RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3854                         CustomStateMachineUseCB);
3855      RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3856                         CustomStateMachineUseCB);
3857    }
3858
3859    // If we do not perform SPMDzation we do not need the virtual uses below.
3860    if (SPMDCompatibilityTracker.isAtFixpoint())
3861      return;
3862
3863    Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3864        [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3865          // Whenever we perform SPMDzation we will insert
3866          // __kmpc_get_hardware_thread_id_in_block calls.
3867          if (!SPMDCompatibilityTracker.isValidState())
3868            return AddDependence(A, this, QueryingAA);
3869          return false;
3870        };
3871    RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3872                       HWThreadIdUseCB);
3873
3874    Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3875        [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3876          // Whenever we perform SPMDzation with guarding we will insert
3877          // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
3878          // nothing to guard, or there are no parallel regions, we don't need
3879          // the calls.
3880          if (!SPMDCompatibilityTracker.isValidState())
3881            return AddDependence(A, this, QueryingAA);
3882          if (SPMDCompatibilityTracker.empty())
3883            return AddDependence(A, this, QueryingAA);
3884          if (!mayContainParallelRegion())
3885            return AddDependence(A, this, QueryingAA);
3886          return false;
3887        };
3888    RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3889  }
3890
3891  /// Sanitize the string \p S such that it is a suitable global symbol name.
3892  static std::string sanitizeForGlobalName(std::string S) {
3893    std::replace_if(
3894        S.begin(), S.end(),
3895        [](const char C) {
3896          return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3897                   (C >= '0' && C <= '9') || C == '_');
3898        },
3899        '.');
3900    return S;
3901  }
3902
3903  /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3904  /// finished now.
3905  ChangeStatus manifest(Attributor &A) override {
3906    // If we are not looking at a kernel with __kmpc_target_init and
3907    // __kmpc_target_deinit call we cannot actually manifest the information.
3908    if (!KernelInitCB || !KernelDeinitCB)
3909      return ChangeStatus::UNCHANGED;
3910
3911    ChangeStatus Changed = ChangeStatus::UNCHANGED;
3912
3913    bool HasBuiltStateMachine = true;
3914    if (!changeToSPMDMode(A, Changed)) {
3915      if (!KernelInitCB->getCalledFunction()->isDeclaration())
3916        HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3917      else
3918        HasBuiltStateMachine = false;
3919    }
3920
3921    // We need to reset KernelEnvC if specific rewriting is not done.
3922    ConstantStruct *ExistingKernelEnvC =
3923        KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3924    ConstantInt *OldUseGenericStateMachineVal =
3925        KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3926            ExistingKernelEnvC);
3927    if (!HasBuiltStateMachine)
3928      setUseGenericStateMachineOfKernelEnvironment(
3929          OldUseGenericStateMachineVal);
3930
3931    // At last, update the KernelEnvc
3932    GlobalVariable *KernelEnvGV =
3933        KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3934    if (KernelEnvGV->getInitializer() != KernelEnvC) {
3935      KernelEnvGV->setInitializer(KernelEnvC);
3936      Changed = ChangeStatus::CHANGED;
3937    }
3938
3939    return Changed;
3940  }
3941
3942  void insertInstructionGuardsHelper(Attributor &A) {
3943    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3944
3945    auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3946                                   Instruction *RegionEndI) {
3947      LoopInfo *LI = nullptr;
3948      DominatorTree *DT = nullptr;
3949      MemorySSAUpdater *MSU = nullptr;
3950      using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3951
3952      BasicBlock *ParentBB = RegionStartI->getParent();
3953      Function *Fn = ParentBB->getParent();
3954      Module &M = *Fn->getParent();
3955
3956      // Create all the blocks and logic.
3957      // ParentBB:
3958      //    goto RegionCheckTidBB
3959      // RegionCheckTidBB:
3960      //    Tid = __kmpc_hardware_thread_id()
3961      //    if (Tid != 0)
3962      //        goto RegionBarrierBB
3963      // RegionStartBB:
3964      //    <execute instructions guarded>
3965      //    goto RegionEndBB
3966      // RegionEndBB:
3967      //    <store escaping values to shared mem>
3968      //    goto RegionBarrierBB
3969      //  RegionBarrierBB:
3970      //    __kmpc_simple_barrier_spmd()
3971      //    // second barrier is omitted if lacking escaping values.
3972      //    <load escaping values from shared mem>
3973      //    __kmpc_simple_barrier_spmd()
3974      //    goto RegionExitBB
3975      // RegionExitBB:
3976      //    <execute rest of instructions>
3977
3978      BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3979                                           DT, LI, MSU, "region.guarded.end");
3980      BasicBlock *RegionBarrierBB =
3981          SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3982                     MSU, "region.barrier");
3983      BasicBlock *RegionExitBB =
3984          SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3985                     DT, LI, MSU, "region.exit");
3986      BasicBlock *RegionStartBB =
3987          SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3988
3989      assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3990             "Expected a different CFG");
3991
3992      BasicBlock *RegionCheckTidBB = SplitBlock(
3993          ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3994
3995      // Register basic blocks with the Attributor.
3996      A.registerManifestAddedBasicBlock(*RegionEndBB);
3997      A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3998      A.registerManifestAddedBasicBlock(*RegionExitBB);
3999      A.registerManifestAddedBasicBlock(*RegionStartBB);
4000      A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4001
4002      bool HasBroadcastValues = false;
4003      // Find escaping outputs from the guarded region to outside users and
4004      // broadcast their values to them.
4005      for (Instruction &I : *RegionStartBB) {
4006        SmallVector<Use *, 4> OutsideUses;
4007        for (Use &U : I.uses()) {
4008          Instruction &UsrI = *cast<Instruction>(U.getUser());
4009          if (UsrI.getParent() != RegionStartBB)
4010            OutsideUses.push_back(&U);
4011        }
4012
4013        if (OutsideUses.empty())
4014          continue;
4015
4016        HasBroadcastValues = true;
4017
4018        // Emit a global variable in shared memory to store the broadcasted
4019        // value.
4020        auto *SharedMem = new GlobalVariable(
4021            M, I.getType(), /* IsConstant */ false,
4022            GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
4023            sanitizeForGlobalName(
4024                (I.getName() + ".guarded.output.alloc").str()),
4025            nullptr, GlobalValue::NotThreadLocal,
4026            static_cast<unsigned>(AddressSpace::Shared));
4027
4028        // Emit a store instruction to update the value.
4029        new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
4030
4031        LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
4032                                       I.getName() + ".guarded.output.load",
4033                                       RegionBarrierBB->getTerminator());
4034
4035        // Emit a load instruction and replace uses of the output value.
4036        for (Use *U : OutsideUses)
4037          A.changeUseAfterManifest(*U, *LoadI);
4038      }
4039
4040      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4041
4042      // Go to tid check BB in ParentBB.
4043      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4044      ParentBB->getTerminator()->eraseFromParent();
4045      OpenMPIRBuilder::LocationDescription Loc(
4046          InsertPointTy(ParentBB, ParentBB->end()), DL);
4047      OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4048      uint32_t SrcLocStrSize;
4049      auto *SrcLocStr =
4050          OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4051      Value *Ident =
4052          OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4053      BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4054
4055      // Add check for Tid in RegionCheckTidBB
4056      RegionCheckTidBB->getTerminator()->eraseFromParent();
4057      OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4058          InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4059      OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4060      FunctionCallee HardwareTidFn =
4061          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4062              M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4063      CallInst *Tid =
4064          OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4065      Tid->setDebugLoc(DL);
4066      OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4067      Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4068      OMPInfoCache.OMPBuilder.Builder
4069          .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4070          ->setDebugLoc(DL);
4071
4072      // First barrier for synchronization, ensures main thread has updated
4073      // values.
4074      FunctionCallee BarrierFn =
4075          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4076              M, OMPRTL___kmpc_barrier_simple_spmd);
4077      OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4078          RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4079      CallInst *Barrier =
4080          OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4081      Barrier->setDebugLoc(DL);
4082      OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4083
4084      // Second barrier ensures workers have read broadcast values.
4085      if (HasBroadcastValues) {
4086        CallInst *Barrier = CallInst::Create(BarrierFn, {Ident, Tid}, "",
4087                                             RegionBarrierBB->getTerminator());
4088        Barrier->setDebugLoc(DL);
4089        OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4090      }
4091    };
4092
4093    auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4094    SmallPtrSet<BasicBlock *, 8> Visited;
4095    for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4096      BasicBlock *BB = GuardedI->getParent();
4097      if (!Visited.insert(BB).second)
4098        continue;
4099
4100      SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
4101      Instruction *LastEffect = nullptr;
4102      BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4103      while (++IP != IPEnd) {
4104        if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4105          continue;
4106        Instruction *I = &*IP;
4107        if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4108          continue;
4109        if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4110          LastEffect = nullptr;
4111          continue;
4112        }
4113        if (LastEffect)
4114          Reorders.push_back({I, LastEffect});
4115        LastEffect = &*IP;
4116      }
4117      for (auto &Reorder : Reorders)
4118        Reorder.first->moveBefore(Reorder.second);
4119    }
4120
4121    SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4122
4123    for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4124      BasicBlock *BB = GuardedI->getParent();
4125      auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4126          IRPosition::function(*GuardedI->getFunction()), nullptr,
4127          DepClassTy::NONE);
4128      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4129      auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4130      // Continue if instruction is already guarded.
4131      if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4132        continue;
4133
4134      Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4135      for (Instruction &I : *BB) {
4136        // If instruction I needs to be guarded update the guarded region
4137        // bounds.
4138        if (SPMDCompatibilityTracker.contains(&I)) {
4139          CalleeAAFunction.getGuardedInstructions().insert(&I);
4140          if (GuardedRegionStart)
4141            GuardedRegionEnd = &I;
4142          else
4143            GuardedRegionStart = GuardedRegionEnd = &I;
4144
4145          continue;
4146        }
4147
4148        // Instruction I does not need guarding, store
4149        // any region found and reset bounds.
4150        if (GuardedRegionStart) {
4151          GuardedRegions.push_back(
4152              std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4153          GuardedRegionStart = nullptr;
4154          GuardedRegionEnd = nullptr;
4155        }
4156      }
4157    }
4158
4159    for (auto &GR : GuardedRegions)
4160      CreateGuardedRegion(GR.first, GR.second);
4161  }
4162
4163  void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4164    // Only allow 1 thread per workgroup to continue executing the user code.
4165    //
4166    //     InitCB = __kmpc_target_init(...)
4167    //     ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4168    //     if (ThreadIdInBlock != 0) return;
4169    // UserCode:
4170    //     // user code
4171    //
4172    auto &Ctx = getAnchorValue().getContext();
4173    Function *Kernel = getAssociatedFunction();
4174    assert(Kernel && "Expected an associated function!");
4175
4176    // Create block for user code to branch to from initial block.
4177    BasicBlock *InitBB = KernelInitCB->getParent();
4178    BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4179        KernelInitCB->getNextNode(), "main.thread.user_code");
4180    BasicBlock *ReturnBB =
4181        BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4182
4183    // Register blocks with attributor:
4184    A.registerManifestAddedBasicBlock(*InitBB);
4185    A.registerManifestAddedBasicBlock(*UserCodeBB);
4186    A.registerManifestAddedBasicBlock(*ReturnBB);
4187
4188    // Debug location:
4189    const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4190    ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4191    InitBB->getTerminator()->eraseFromParent();
4192
4193    // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4194    Module &M = *Kernel->getParent();
4195    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4196    FunctionCallee ThreadIdInBlockFn =
4197        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4198            M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4199
4200    // Get thread ID in block.
4201    CallInst *ThreadIdInBlock =
4202        CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4203    OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4204    ThreadIdInBlock->setDebugLoc(DLoc);
4205
4206    // Eliminate all threads in the block with ID not equal to 0:
4207    Instruction *IsMainThread =
4208        ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4209                         ConstantInt::get(ThreadIdInBlock->getType(), 0),
4210                         "thread.is_main", InitBB);
4211    IsMainThread->setDebugLoc(DLoc);
4212    BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
4213  }
4214
4215  bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4216    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4217
4218    // We cannot change to SPMD mode if the runtime functions aren't availible.
4219    if (!OMPInfoCache.runtimeFnsAvailable(
4220            {OMPRTL___kmpc_get_hardware_thread_id_in_block,
4221             OMPRTL___kmpc_barrier_simple_spmd}))
4222      return false;
4223
4224    if (!SPMDCompatibilityTracker.isAssumed()) {
4225      for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4226        if (!NonCompatibleI)
4227          continue;
4228
4229        // Skip diagnostics on calls to known OpenMP runtime functions for now.
4230        if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4231          if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4232            continue;
4233
4234        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4235          ORA << "Value has potential side effects preventing SPMD-mode "
4236                 "execution";
4237          if (isa<CallBase>(NonCompatibleI)) {
4238            ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
4239                   "the called function to override";
4240          }
4241          return ORA << ".";
4242        };
4243        A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4244                                                 Remark);
4245
4246        LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4247                          << *NonCompatibleI << "\n");
4248      }
4249
4250      return false;
4251    }
4252
4253    // Get the actual kernel, could be the caller of the anchor scope if we have
4254    // a debug wrapper.
4255    Function *Kernel = getAnchorScope();
4256    if (Kernel->hasLocalLinkage()) {
4257      assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4258      auto *CB = cast<CallBase>(Kernel->user_back());
4259      Kernel = CB->getCaller();
4260    }
4261    assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4262
4263    // Check if the kernel is already in SPMD mode, if so, return success.
4264    ConstantStruct *ExistingKernelEnvC =
4265        KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4266    auto *ExecModeC =
4267        KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4268    const int8_t ExecModeVal = ExecModeC->getSExtValue();
4269    if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4270      return true;
4271
4272    // We will now unconditionally modify the IR, indicate a change.
4273    Changed = ChangeStatus::CHANGED;
4274
4275    // Do not use instruction guards when no parallel is present inside
4276    // the target region.
4277    if (mayContainParallelRegion())
4278      insertInstructionGuardsHelper(A);
4279    else
4280      forceSingleThreadPerWorkgroupHelper(A);
4281
4282    // Adjust the global exec mode flag that tells the runtime what mode this
4283    // kernel is executed in.
4284    assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4285           "Initially non-SPMD kernel has SPMD exec mode!");
4286    setExecModeOfKernelEnvironment(
4287        ConstantInt::get(ExecModeC->getIntegerType(),
4288                         ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
4289
4290    ++NumOpenMPTargetRegionKernelsSPMD;
4291
4292    auto Remark = [&](OptimizationRemark OR) {
4293      return OR << "Transformed generic-mode kernel to SPMD-mode.";
4294    };
4295    A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4296    return true;
4297  };
4298
4299  bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4300    // If we have disabled state machine rewrites, don't make a custom one
4301    if (DisableOpenMPOptStateMachineRewrite)
4302      return false;
4303
4304    // Don't rewrite the state machine if we are not in a valid state.
4305    if (!ReachedKnownParallelRegions.isValidState())
4306      return false;
4307
4308    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4309    if (!OMPInfoCache.runtimeFnsAvailable(
4310            {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4311             OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4312             OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4313      return false;
4314
4315    ConstantStruct *ExistingKernelEnvC =
4316        KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4317
4318    // Check if the current configuration is non-SPMD and generic state machine.
4319    // If we already have SPMD mode or a custom state machine we do not need to
4320    // go any further. If it is anything but a constant something is weird and
4321    // we give up.
4322    ConstantInt *UseStateMachineC =
4323        KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4324            ExistingKernelEnvC);
4325    ConstantInt *ModeC =
4326        KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4327
4328    // If we are stuck with generic mode, try to create a custom device (=GPU)
4329    // state machine which is specialized for the parallel regions that are
4330    // reachable by the kernel.
4331    if (UseStateMachineC->isZero() ||
4332        (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
4333      return false;
4334
4335    Changed = ChangeStatus::CHANGED;
4336
4337    // If not SPMD mode, indicate we use a custom state machine now.
4338    setUseGenericStateMachineOfKernelEnvironment(
4339        ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4340
4341    // If we don't actually need a state machine we are done here. This can
4342    // happen if there simply are no parallel regions. In the resulting kernel
4343    // all worker threads will simply exit right away, leaving the main thread
4344    // to do the work alone.
4345    if (!mayContainParallelRegion()) {
4346      ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4347
4348      auto Remark = [&](OptimizationRemark OR) {
4349        return OR << "Removing unused state machine from generic-mode kernel.";
4350      };
4351      A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4352
4353      return true;
4354    }
4355
4356    // Keep track in the statistics of our new shiny custom state machine.
4357    if (ReachedUnknownParallelRegions.empty()) {
4358      ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4359
4360      auto Remark = [&](OptimizationRemark OR) {
4361        return OR << "Rewriting generic-mode kernel with a customized state "
4362                     "machine.";
4363      };
4364      A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4365    } else {
4366      ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4367
4368      auto Remark = [&](OptimizationRemarkAnalysis OR) {
4369        return OR << "Generic-mode kernel is executed with a customized state "
4370                     "machine that requires a fallback.";
4371      };
4372      A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4373
4374      // Tell the user why we ended up with a fallback.
4375      for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4376        if (!UnknownParallelRegionCB)
4377          continue;
4378        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4379          return ORA << "Call may contain unknown parallel regions. Use "
4380                     << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
4381                        "override.";
4382        };
4383        A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4384                                                 "OMP133", Remark);
4385      }
4386    }
4387
4388    // Create all the blocks:
4389    //
4390    //                       InitCB = __kmpc_target_init(...)
4391    //                       BlockHwSize =
4392    //                         __kmpc_get_hardware_num_threads_in_block();
4393    //                       WarpSize = __kmpc_get_warp_size();
4394    //                       BlockSize = BlockHwSize - WarpSize;
4395    // IsWorkerCheckBB:      bool IsWorker = InitCB != -1;
4396    //                       if (IsWorker) {
4397    //                         if (InitCB >= BlockSize) return;
4398    // SMBeginBB:               __kmpc_barrier_simple_generic(...);
4399    //                         void *WorkFn;
4400    //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
4401    //                         if (!WorkFn) return;
4402    // SMIsActiveCheckBB:       if (Active) {
4403    // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
4404    //                              ParFn0(...);
4405    // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
4406    //                              ParFn1(...);
4407    //                            ...
4408    // SMIfCascadeCurrentBB:      else
4409    //                              ((WorkFnTy*)WorkFn)(...);
4410    // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
4411    //                          }
4412    // SMDoneBB:                __kmpc_barrier_simple_generic(...);
4413    //                          goto SMBeginBB;
4414    //                       }
4415    // UserCodeEntryBB:      // user code
4416    //                       __kmpc_target_deinit(...)
4417    //
4418    auto &Ctx = getAnchorValue().getContext();
4419    Function *Kernel = getAssociatedFunction();
4420    assert(Kernel && "Expected an associated function!");
4421
4422    BasicBlock *InitBB = KernelInitCB->getParent();
4423    BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4424        KernelInitCB->getNextNode(), "thread.user_code.check");
4425    BasicBlock *IsWorkerCheckBB =
4426        BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4427    BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4428        Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4429    BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4430        Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4431    BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4432        Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4433    BasicBlock *StateMachineIfCascadeCurrentBB =
4434        BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4435                           Kernel, UserCodeEntryBB);
4436    BasicBlock *StateMachineEndParallelBB =
4437        BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4438                           Kernel, UserCodeEntryBB);
4439    BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4440        Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4441    A.registerManifestAddedBasicBlock(*InitBB);
4442    A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4443    A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4444    A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4445    A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4446    A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4447    A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4448    A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4449    A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4450
4451    const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4452    ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4453    InitBB->getTerminator()->eraseFromParent();
4454
4455    Instruction *IsWorker =
4456        ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4457                         ConstantInt::get(KernelInitCB->getType(), -1),
4458                         "thread.is_worker", InitBB);
4459    IsWorker->setDebugLoc(DLoc);
4460    BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4461
4462    Module &M = *Kernel->getParent();
4463    FunctionCallee BlockHwSizeFn =
4464        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4465            M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4466    FunctionCallee WarpSizeFn =
4467        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4468            M, OMPRTL___kmpc_get_warp_size);
4469    CallInst *BlockHwSize =
4470        CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4471    OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4472    BlockHwSize->setDebugLoc(DLoc);
4473    CallInst *WarpSize =
4474        CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4475    OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4476    WarpSize->setDebugLoc(DLoc);
4477    Instruction *BlockSize = BinaryOperator::CreateSub(
4478        BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4479    BlockSize->setDebugLoc(DLoc);
4480    Instruction *IsMainOrWorker = ICmpInst::Create(
4481        ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4482        "thread.is_main_or_worker", IsWorkerCheckBB);
4483    IsMainOrWorker->setDebugLoc(DLoc);
4484    BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4485                       IsMainOrWorker, IsWorkerCheckBB);
4486
4487    // Create local storage for the work function pointer.
4488    const DataLayout &DL = M.getDataLayout();
4489    Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4490    Instruction *WorkFnAI =
4491        new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4492                       "worker.work_fn.addr", &Kernel->getEntryBlock().front());
4493    WorkFnAI->setDebugLoc(DLoc);
4494
4495    OMPInfoCache.OMPBuilder.updateToLocation(
4496        OpenMPIRBuilder::LocationDescription(
4497            IRBuilder<>::InsertPoint(StateMachineBeginBB,
4498                                     StateMachineBeginBB->end()),
4499            DLoc));
4500
4501    Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4502    Value *GTid = KernelInitCB;
4503
4504    FunctionCallee BarrierFn =
4505        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4506            M, OMPRTL___kmpc_barrier_simple_generic);
4507    CallInst *Barrier =
4508        CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4509    OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4510    Barrier->setDebugLoc(DLoc);
4511
4512    if (WorkFnAI->getType()->getPointerAddressSpace() !=
4513        (unsigned int)AddressSpace::Generic) {
4514      WorkFnAI = new AddrSpaceCastInst(
4515          WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4516          WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4517      WorkFnAI->setDebugLoc(DLoc);
4518    }
4519
4520    FunctionCallee KernelParallelFn =
4521        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4522            M, OMPRTL___kmpc_kernel_parallel);
4523    CallInst *IsActiveWorker = CallInst::Create(
4524        KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4525    OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4526    IsActiveWorker->setDebugLoc(DLoc);
4527    Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4528                                       StateMachineBeginBB);
4529    WorkFn->setDebugLoc(DLoc);
4530
4531    FunctionType *ParallelRegionFnTy = FunctionType::get(
4532        Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4533        false);
4534
4535    Instruction *IsDone =
4536        ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4537                         Constant::getNullValue(VoidPtrTy), "worker.is_done",
4538                         StateMachineBeginBB);
4539    IsDone->setDebugLoc(DLoc);
4540    BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4541                       IsDone, StateMachineBeginBB)
4542        ->setDebugLoc(DLoc);
4543
4544    BranchInst::Create(StateMachineIfCascadeCurrentBB,
4545                       StateMachineDoneBarrierBB, IsActiveWorker,
4546                       StateMachineIsActiveCheckBB)
4547        ->setDebugLoc(DLoc);
4548
4549    Value *ZeroArg =
4550        Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4551
4552    const unsigned int WrapperFunctionArgNo = 6;
4553
4554    // Now that we have most of the CFG skeleton it is time for the if-cascade
4555    // that checks the function pointer we got from the runtime against the
4556    // parallel regions we expect, if there are any.
4557    for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4558      auto *CB = ReachedKnownParallelRegions[I];
4559      auto *ParallelRegion = dyn_cast<Function>(
4560          CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4561      BasicBlock *PRExecuteBB = BasicBlock::Create(
4562          Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4563          StateMachineEndParallelBB);
4564      CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4565          ->setDebugLoc(DLoc);
4566      BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4567          ->setDebugLoc(DLoc);
4568
4569      BasicBlock *PRNextBB =
4570          BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4571                             Kernel, StateMachineEndParallelBB);
4572      A.registerManifestAddedBasicBlock(*PRExecuteBB);
4573      A.registerManifestAddedBasicBlock(*PRNextBB);
4574
4575      // Check if we need to compare the pointer at all or if we can just
4576      // call the parallel region function.
4577      Value *IsPR;
4578      if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4579        Instruction *CmpI = ICmpInst::Create(
4580            ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4581            "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4582        CmpI->setDebugLoc(DLoc);
4583        IsPR = CmpI;
4584      } else {
4585        IsPR = ConstantInt::getTrue(Ctx);
4586      }
4587
4588      BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4589                         StateMachineIfCascadeCurrentBB)
4590          ->setDebugLoc(DLoc);
4591      StateMachineIfCascadeCurrentBB = PRNextBB;
4592    }
4593
4594    // At the end of the if-cascade we place the indirect function pointer call
4595    // in case we might need it, that is if there can be parallel regions we
4596    // have not handled in the if-cascade above.
4597    if (!ReachedUnknownParallelRegions.empty()) {
4598      StateMachineIfCascadeCurrentBB->setName(
4599          "worker_state_machine.parallel_region.fallback.execute");
4600      CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4601                       StateMachineIfCascadeCurrentBB)
4602          ->setDebugLoc(DLoc);
4603    }
4604    BranchInst::Create(StateMachineEndParallelBB,
4605                       StateMachineIfCascadeCurrentBB)
4606        ->setDebugLoc(DLoc);
4607
4608    FunctionCallee EndParallelFn =
4609        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4610            M, OMPRTL___kmpc_kernel_end_parallel);
4611    CallInst *EndParallel =
4612        CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4613    OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4614    EndParallel->setDebugLoc(DLoc);
4615    BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4616        ->setDebugLoc(DLoc);
4617
4618    CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4619        ->setDebugLoc(DLoc);
4620    BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4621        ->setDebugLoc(DLoc);
4622
4623    return true;
4624  }
4625
4626  /// Fixpoint iteration update function. Will be called every time a dependence
4627  /// changed its state (and in the beginning).
4628  ChangeStatus updateImpl(Attributor &A) override {
4629    KernelInfoState StateBefore = getState();
4630
4631    // When we leave this function this RAII will make sure the member
4632    // KernelEnvC is updated properly depending on the state. That member is
4633    // used for simplification of values and needs to be up to date at all
4634    // times.
4635    struct UpdateKernelEnvCRAII {
4636      AAKernelInfoFunction &AA;
4637
4638      UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4639
4640      ~UpdateKernelEnvCRAII() {
4641        if (!AA.KernelEnvC)
4642          return;
4643
4644        ConstantStruct *ExistingKernelEnvC =
4645            KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);
4646
4647        if (!AA.isValidState()) {
4648          AA.KernelEnvC = ExistingKernelEnvC;
4649          return;
4650        }
4651
4652        if (!AA.ReachedKnownParallelRegions.isValidState())
4653          AA.setUseGenericStateMachineOfKernelEnvironment(
4654              KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4655                  ExistingKernelEnvC));
4656
4657        if (!AA.SPMDCompatibilityTracker.isValidState())
4658          AA.setExecModeOfKernelEnvironment(
4659              KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4660
4661        ConstantInt *MayUseNestedParallelismC =
4662            KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4663                AA.KernelEnvC);
4664        ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4665            MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4666        AA.setMayUseNestedParallelismOfKernelEnvironment(
4667            NewMayUseNestedParallelismC);
4668      }
4669    } RAII(*this);
4670
4671    // Callback to check a read/write instruction.
4672    auto CheckRWInst = [&](Instruction &I) {
4673      // We handle calls later.
4674      if (isa<CallBase>(I))
4675        return true;
4676      // We only care about write effects.
4677      if (!I.mayWriteToMemory())
4678        return true;
4679      if (auto *SI = dyn_cast<StoreInst>(&I)) {
4680        const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4681            *this, IRPosition::value(*SI->getPointerOperand()),
4682            DepClassTy::OPTIONAL);
4683        auto *HS = A.getAAFor<AAHeapToStack>(
4684            *this, IRPosition::function(*I.getFunction()),
4685            DepClassTy::OPTIONAL);
4686        if (UnderlyingObjsAA &&
4687            UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4688              if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4689                return true;
4690              // Check for AAHeapToStack moved objects which must not be
4691              // guarded.
4692              auto *CB = dyn_cast<CallBase>(&Obj);
4693              return CB && HS && HS->isAssumedHeapToStack(*CB);
4694            }))
4695          return true;
4696      }
4697
4698      // Insert instruction that needs guarding.
4699      SPMDCompatibilityTracker.insert(&I);
4700      return true;
4701    };
4702
4703    bool UsedAssumedInformationInCheckRWInst = false;
4704    if (!SPMDCompatibilityTracker.isAtFixpoint())
4705      if (!A.checkForAllReadWriteInstructions(
4706              CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4707        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4708
4709    bool UsedAssumedInformationFromReachingKernels = false;
4710    if (!IsKernelEntry) {
4711      updateParallelLevels(A);
4712
4713      bool AllReachingKernelsKnown = true;
4714      updateReachingKernelEntries(A, AllReachingKernelsKnown);
4715      UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4716
4717      if (!SPMDCompatibilityTracker.empty()) {
4718        if (!ParallelLevels.isValidState())
4719          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4720        else if (!ReachingKernelEntries.isValidState())
4721          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4722        else {
4723          // Check if all reaching kernels agree on the mode as we can otherwise
4724          // not guard instructions. We might not be sure about the mode so we
4725          // we cannot fix the internal spmd-zation state either.
4726          int SPMD = 0, Generic = 0;
4727          for (auto *Kernel : ReachingKernelEntries) {
4728            auto *CBAA = A.getAAFor<AAKernelInfo>(
4729                *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4730            if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4731                CBAA->SPMDCompatibilityTracker.isAssumed())
4732              ++SPMD;
4733            else
4734              ++Generic;
4735            if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4736              UsedAssumedInformationFromReachingKernels = true;
4737          }
4738          if (SPMD != 0 && Generic != 0)
4739            SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4740        }
4741      }
4742    }
4743
4744    // Callback to check a call instruction.
4745    bool AllParallelRegionStatesWereFixed = true;
4746    bool AllSPMDStatesWereFixed = true;
4747    auto CheckCallInst = [&](Instruction &I) {
4748      auto &CB = cast<CallBase>(I);
4749      auto *CBAA = A.getAAFor<AAKernelInfo>(
4750          *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4751      if (!CBAA)
4752        return false;
4753      getState() ^= CBAA->getState();
4754      AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4755      AllParallelRegionStatesWereFixed &=
4756          CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4757      AllParallelRegionStatesWereFixed &=
4758          CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4759      return true;
4760    };
4761
4762    bool UsedAssumedInformationInCheckCallInst = false;
4763    if (!A.checkForAllCallLikeInstructions(
4764            CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4765      LLVM_DEBUG(dbgs() << TAG
4766                        << "Failed to visit all call-like instructions!\n";);
4767      return indicatePessimisticFixpoint();
4768    }
4769
4770    // If we haven't used any assumed information for the reached parallel
4771    // region states we can fix it.
4772    if (!UsedAssumedInformationInCheckCallInst &&
4773        AllParallelRegionStatesWereFixed) {
4774      ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4775      ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4776    }
4777
4778    // If we haven't used any assumed information for the SPMD state we can fix
4779    // it.
4780    if (!UsedAssumedInformationInCheckRWInst &&
4781        !UsedAssumedInformationInCheckCallInst &&
4782        !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4783      SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4784
4785    return StateBefore == getState() ? ChangeStatus::UNCHANGED
4786                                     : ChangeStatus::CHANGED;
4787  }
4788
4789private:
4790  /// Update info regarding reaching kernels.
4791  void updateReachingKernelEntries(Attributor &A,
4792                                   bool &AllReachingKernelsKnown) {
4793    auto PredCallSite = [&](AbstractCallSite ACS) {
4794      Function *Caller = ACS.getInstruction()->getFunction();
4795
4796      assert(Caller && "Caller is nullptr");
4797
4798      auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4799          IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4800      if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4801        ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4802        return true;
4803      }
4804
4805      // We lost track of the caller of the associated function, any kernel
4806      // could reach now.
4807      ReachingKernelEntries.indicatePessimisticFixpoint();
4808
4809      return true;
4810    };
4811
4812    if (!A.checkForAllCallSites(PredCallSite, *this,
4813                                true /* RequireAllCallSites */,
4814                                AllReachingKernelsKnown))
4815      ReachingKernelEntries.indicatePessimisticFixpoint();
4816  }
4817
4818  /// Update info regarding parallel levels.
4819  void updateParallelLevels(Attributor &A) {
4820    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4821    OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4822        OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4823
4824    auto PredCallSite = [&](AbstractCallSite ACS) {
4825      Function *Caller = ACS.getInstruction()->getFunction();
4826
4827      assert(Caller && "Caller is nullptr");
4828
4829      auto *CAA =
4830          A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4831      if (CAA && CAA->ParallelLevels.isValidState()) {
4832        // Any function that is called by `__kmpc_parallel_51` will not be
4833        // folded as the parallel level in the function is updated. In order to
4834        // get it right, all the analysis would depend on the implentation. That
4835        // said, if in the future any change to the implementation, the analysis
4836        // could be wrong. As a consequence, we are just conservative here.
4837        if (Caller == Parallel51RFI.Declaration) {
4838          ParallelLevels.indicatePessimisticFixpoint();
4839          return true;
4840        }
4841
4842        ParallelLevels ^= CAA->ParallelLevels;
4843
4844        return true;
4845      }
4846
4847      // We lost track of the caller of the associated function, any kernel
4848      // could reach now.
4849      ParallelLevels.indicatePessimisticFixpoint();
4850
4851      return true;
4852    };
4853
4854    bool AllCallSitesKnown = true;
4855    if (!A.checkForAllCallSites(PredCallSite, *this,
4856                                true /* RequireAllCallSites */,
4857                                AllCallSitesKnown))
4858      ParallelLevels.indicatePessimisticFixpoint();
4859  }
4860};
4861
4862/// The call site kernel info abstract attribute, basically, what can we say
4863/// about a call site with regards to the KernelInfoState. For now this simply
4864/// forwards the information from the callee.
4865struct AAKernelInfoCallSite : AAKernelInfo {
4866  AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4867      : AAKernelInfo(IRP, A) {}
4868
4869  /// See AbstractAttribute::initialize(...).
4870  void initialize(Attributor &A) override {
4871    AAKernelInfo::initialize(A);
4872
4873    CallBase &CB = cast<CallBase>(getAssociatedValue());
4874    auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4875        *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4876
4877    // Check for SPMD-mode assumptions.
4878    if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4879      indicateOptimisticFixpoint();
4880      return;
4881    }
4882
4883    // First weed out calls we do not care about, that is readonly/readnone
4884    // calls, intrinsics, and "no_openmp" calls. Neither of these can reach a
4885    // parallel region or anything else we are looking for.
4886    if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4887      indicateOptimisticFixpoint();
4888      return;
4889    }
4890
4891    // Next we check if we know the callee. If it is a known OpenMP function
4892    // we will handle them explicitly in the switch below. If it is not, we
4893    // will use an AAKernelInfo object on the callee to gather information and
4894    // merge that into the current state. The latter happens in the updateImpl.
4895    auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4896      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4897      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4898      if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4899        // Unknown caller or declarations are not analyzable, we give up.
4900        if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4901
4902          // Unknown callees might contain parallel regions, except if they have
4903          // an appropriate assumption attached.
4904          if (!AssumptionAA ||
4905              !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4906                AssumptionAA->hasAssumption("omp_no_parallelism")))
4907            ReachedUnknownParallelRegions.insert(&CB);
4908
4909          // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4910          // idea we can run something unknown in SPMD-mode.
4911          if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4912            SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4913            SPMDCompatibilityTracker.insert(&CB);
4914          }
4915
4916          // We have updated the state for this unknown call properly, there
4917          // won't be any change so we indicate a fixpoint.
4918          indicateOptimisticFixpoint();
4919        }
4920        // If the callee is known and can be used in IPO, we will update the
4921        // state based on the callee state in updateImpl.
4922        return;
4923      }
4924      if (NumCallees > 1) {
4925        indicatePessimisticFixpoint();
4926        return;
4927      }
4928
4929      RuntimeFunction RF = It->getSecond();
4930      switch (RF) {
4931      // All the functions we know are compatible with SPMD mode.
4932      case OMPRTL___kmpc_is_spmd_exec_mode:
4933      case OMPRTL___kmpc_distribute_static_fini:
4934      case OMPRTL___kmpc_for_static_fini:
4935      case OMPRTL___kmpc_global_thread_num:
4936      case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4937      case OMPRTL___kmpc_get_hardware_num_blocks:
4938      case OMPRTL___kmpc_single:
4939      case OMPRTL___kmpc_end_single:
4940      case OMPRTL___kmpc_master:
4941      case OMPRTL___kmpc_end_master:
4942      case OMPRTL___kmpc_barrier:
4943      case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4944      case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4945      case OMPRTL___kmpc_error:
4946      case OMPRTL___kmpc_flush:
4947      case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4948      case OMPRTL___kmpc_get_warp_size:
4949      case OMPRTL_omp_get_thread_num:
4950      case OMPRTL_omp_get_num_threads:
4951      case OMPRTL_omp_get_max_threads:
4952      case OMPRTL_omp_in_parallel:
4953      case OMPRTL_omp_get_dynamic:
4954      case OMPRTL_omp_get_cancellation:
4955      case OMPRTL_omp_get_nested:
4956      case OMPRTL_omp_get_schedule:
4957      case OMPRTL_omp_get_thread_limit:
4958      case OMPRTL_omp_get_supported_active_levels:
4959      case OMPRTL_omp_get_max_active_levels:
4960      case OMPRTL_omp_get_level:
4961      case OMPRTL_omp_get_ancestor_thread_num:
4962      case OMPRTL_omp_get_team_size:
4963      case OMPRTL_omp_get_active_level:
4964      case OMPRTL_omp_in_final:
4965      case OMPRTL_omp_get_proc_bind:
4966      case OMPRTL_omp_get_num_places:
4967      case OMPRTL_omp_get_num_procs:
4968      case OMPRTL_omp_get_place_proc_ids:
4969      case OMPRTL_omp_get_place_num:
4970      case OMPRTL_omp_get_partition_num_places:
4971      case OMPRTL_omp_get_partition_place_nums:
4972      case OMPRTL_omp_get_wtime:
4973        break;
4974      case OMPRTL___kmpc_distribute_static_init_4:
4975      case OMPRTL___kmpc_distribute_static_init_4u:
4976      case OMPRTL___kmpc_distribute_static_init_8:
4977      case OMPRTL___kmpc_distribute_static_init_8u:
4978      case OMPRTL___kmpc_for_static_init_4:
4979      case OMPRTL___kmpc_for_static_init_4u:
4980      case OMPRTL___kmpc_for_static_init_8:
4981      case OMPRTL___kmpc_for_static_init_8u: {
4982        // Check the schedule and allow static schedule in SPMD mode.
4983        unsigned ScheduleArgOpNo = 2;
4984        auto *ScheduleTypeCI =
4985            dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4986        unsigned ScheduleTypeVal =
4987            ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4988        switch (OMPScheduleType(ScheduleTypeVal)) {
4989        case OMPScheduleType::UnorderedStatic:
4990        case OMPScheduleType::UnorderedStaticChunked:
4991        case OMPScheduleType::OrderedDistribute:
4992        case OMPScheduleType::OrderedDistributeChunked:
4993          break;
4994        default:
4995          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4996          SPMDCompatibilityTracker.insert(&CB);
4997          break;
4998        };
4999      } break;
5000      case OMPRTL___kmpc_target_init:
5001        KernelInitCB = &CB;
5002        break;
5003      case OMPRTL___kmpc_target_deinit:
5004        KernelDeinitCB = &CB;
5005        break;
5006      case OMPRTL___kmpc_parallel_51:
5007        if (!handleParallel51(A, CB))
5008          indicatePessimisticFixpoint();
5009        return;
5010      case OMPRTL___kmpc_omp_task:
5011        // We do not look into tasks right now, just give up.
5012        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5013        SPMDCompatibilityTracker.insert(&CB);
5014        ReachedUnknownParallelRegions.insert(&CB);
5015        break;
5016      case OMPRTL___kmpc_alloc_shared:
5017      case OMPRTL___kmpc_free_shared:
5018        // Return without setting a fixpoint, to be resolved in updateImpl.
5019        return;
5020      default:
5021        // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
5022        // generally. However, they do not hide parallel regions.
5023        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5024        SPMDCompatibilityTracker.insert(&CB);
5025        break;
5026      }
5027      // All other OpenMP runtime calls will not reach parallel regions so they
5028      // can be safely ignored for now. Since it is a known OpenMP runtime call
5029      // we have now modeled all effects and there is no need for any update.
5030      indicateOptimisticFixpoint();
5031    };
5032
5033    const auto *AACE =
5034        A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5035    if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5036      CheckCallee(getAssociatedFunction(), 1);
5037      return;
5038    }
5039    const auto &OptimisticEdges = AACE->getOptimisticEdges();
5040    for (auto *Callee : OptimisticEdges) {
5041      CheckCallee(Callee, OptimisticEdges.size());
5042      if (isAtFixpoint())
5043        break;
5044    }
5045  }
5046
5047  ChangeStatus updateImpl(Attributor &A) override {
5048    // TODO: Once we have call site specific value information we can provide
5049    //       call site specific liveness information and then it makes
5050    //       sense to specialize attributes for call sites arguments instead of
5051    //       redirecting requests to the callee argument.
5052    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5053    KernelInfoState StateBefore = getState();
5054
5055    auto CheckCallee = [&](Function *F, int NumCallees) {
5056      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5057
5058      // If F is not a runtime function, propagate the AAKernelInfo of the
5059      // callee.
5060      if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5061        const IRPosition &FnPos = IRPosition::function(*F);
5062        auto *FnAA =
5063            A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5064        if (!FnAA)
5065          return indicatePessimisticFixpoint();
5066        if (getState() == FnAA->getState())
5067          return ChangeStatus::UNCHANGED;
5068        getState() = FnAA->getState();
5069        return ChangeStatus::CHANGED;
5070      }
5071      if (NumCallees > 1)
5072        return indicatePessimisticFixpoint();
5073
5074      CallBase &CB = cast<CallBase>(getAssociatedValue());
5075      if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5076        if (!handleParallel51(A, CB))
5077          return indicatePessimisticFixpoint();
5078        return StateBefore == getState() ? ChangeStatus::UNCHANGED
5079                                         : ChangeStatus::CHANGED;
5080      }
5081
5082      // F is a runtime function that allocates or frees memory, check
5083      // AAHeapToStack and AAHeapToShared.
5084      assert(
5085          (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5086           It->getSecond() == OMPRTL___kmpc_free_shared) &&
5087          "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5088
5089      auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5090          *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5091      auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5092          *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5093
5094      RuntimeFunction RF = It->getSecond();
5095
5096      switch (RF) {
5097      // If neither HeapToStack nor HeapToShared assume the call is removed,
5098      // assume SPMD incompatibility.
5099      case OMPRTL___kmpc_alloc_shared:
5100        if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5101            (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5102          SPMDCompatibilityTracker.insert(&CB);
5103        break;
5104      case OMPRTL___kmpc_free_shared:
5105        if ((!HeapToStackAA ||
5106             !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5107            (!HeapToSharedAA ||
5108             !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5109          SPMDCompatibilityTracker.insert(&CB);
5110        break;
5111      default:
5112        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5113        SPMDCompatibilityTracker.insert(&CB);
5114      }
5115      return ChangeStatus::CHANGED;
5116    };
5117
5118    const auto *AACE =
5119        A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5120    if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5121      if (Function *F = getAssociatedFunction())
5122        CheckCallee(F, /*NumCallees=*/1);
5123    } else {
5124      const auto &OptimisticEdges = AACE->getOptimisticEdges();
5125      for (auto *Callee : OptimisticEdges) {
5126        CheckCallee(Callee, OptimisticEdges.size());
5127        if (isAtFixpoint())
5128          break;
5129      }
5130    }
5131
5132    return StateBefore == getState() ? ChangeStatus::UNCHANGED
5133                                     : ChangeStatus::CHANGED;
5134  }
5135
5136  /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was
5137  /// handled, if a problem occurred, false is returned.
5138  bool handleParallel51(Attributor &A, CallBase &CB) {
5139    const unsigned int NonWrapperFunctionArgNo = 5;
5140    const unsigned int WrapperFunctionArgNo = 6;
5141    auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5142                                     ? NonWrapperFunctionArgNo
5143                                     : WrapperFunctionArgNo;
5144
5145    auto *ParallelRegion = dyn_cast<Function>(
5146        CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5147    if (!ParallelRegion)
5148      return false;
5149
5150    ReachedKnownParallelRegions.insert(&CB);
5151    /// Check nested parallelism
5152    auto *FnAA = A.getAAFor<AAKernelInfo>(
5153        *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5154    NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5155                         !FnAA->ReachedKnownParallelRegions.empty() ||
5156                         !FnAA->ReachedKnownParallelRegions.isValidState() ||
5157                         !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5158                         !FnAA->ReachedUnknownParallelRegions.empty();
5159    return true;
5160  }
5161};
5162
5163struct AAFoldRuntimeCall
5164    : public StateWrapper<BooleanState, AbstractAttribute> {
5165  using Base = StateWrapper<BooleanState, AbstractAttribute>;
5166
5167  AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5168
5169  /// Statistics are tracked as part of manifest for now.
5170  void trackStatistics() const override {}
5171
5172  /// Create an abstract attribute biew for the position \p IRP.
5173  static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5174                                              Attributor &A);
5175
5176  /// See AbstractAttribute::getName()
5177  const std::string getName() const override { return "AAFoldRuntimeCall"; }
5178
5179  /// See AbstractAttribute::getIdAddr()
5180  const char *getIdAddr() const override { return &ID; }
5181
5182  /// This function should return true if the type of the \p AA is
5183  /// AAFoldRuntimeCall
5184  static bool classof(const AbstractAttribute *AA) {
5185    return (AA->getIdAddr() == &ID);
5186  }
5187
5188  static const char ID;
5189};
5190
5191struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5192  AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5193      : AAFoldRuntimeCall(IRP, A) {}
5194
5195  /// See AbstractAttribute::getAsStr()
5196  const std::string getAsStr(Attributor *) const override {
5197    if (!isValidState())
5198      return "<invalid>";
5199
5200    std::string Str("simplified value: ");
5201
5202    if (!SimplifiedValue)
5203      return Str + std::string("none");
5204
5205    if (!*SimplifiedValue)
5206      return Str + std::string("nullptr");
5207
5208    if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5209      return Str + std::to_string(CI->getSExtValue());
5210
5211    return Str + std::string("unknown");
5212  }
5213
5214  void initialize(Attributor &A) override {
5215    if (DisableOpenMPOptFolding)
5216      indicatePessimisticFixpoint();
5217
5218    Function *Callee = getAssociatedFunction();
5219
5220    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5221    const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5222    assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5223           "Expected a known OpenMP runtime function");
5224
5225    RFKind = It->getSecond();
5226
5227    CallBase &CB = cast<CallBase>(getAssociatedValue());
5228    A.registerSimplificationCallback(
5229        IRPosition::callsite_returned(CB),
5230        [&](const IRPosition &IRP, const AbstractAttribute *AA,
5231            bool &UsedAssumedInformation) -> std::optional<Value *> {
5232          assert((isValidState() ||
5233                  (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5234                 "Unexpected invalid state!");
5235
5236          if (!isAtFixpoint()) {
5237            UsedAssumedInformation = true;
5238            if (AA)
5239              A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5240          }
5241          return SimplifiedValue;
5242        });
5243  }
5244
5245  ChangeStatus updateImpl(Attributor &A) override {
5246    ChangeStatus Changed = ChangeStatus::UNCHANGED;
5247    switch (RFKind) {
5248    case OMPRTL___kmpc_is_spmd_exec_mode:
5249      Changed |= foldIsSPMDExecMode(A);
5250      break;
5251    case OMPRTL___kmpc_parallel_level:
5252      Changed |= foldParallelLevel(A);
5253      break;
5254    case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5255      Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5256      break;
5257    case OMPRTL___kmpc_get_hardware_num_blocks:
5258      Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5259      break;
5260    default:
5261      llvm_unreachable("Unhandled OpenMP runtime function!");
5262    }
5263
5264    return Changed;
5265  }
5266
5267  ChangeStatus manifest(Attributor &A) override {
5268    ChangeStatus Changed = ChangeStatus::UNCHANGED;
5269
5270    if (SimplifiedValue && *SimplifiedValue) {
5271      Instruction &I = *getCtxI();
5272      A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5273      A.deleteAfterManifest(I);
5274
5275      CallBase *CB = dyn_cast<CallBase>(&I);
5276      auto Remark = [&](OptimizationRemark OR) {
5277        if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5278          return OR << "Replacing OpenMP runtime call "
5279                    << CB->getCalledFunction()->getName() << " with "
5280                    << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5281        return OR << "Replacing OpenMP runtime call "
5282                  << CB->getCalledFunction()->getName() << ".";
5283      };
5284
5285      if (CB && EnableVerboseRemarks)
5286        A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5287
5288      LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5289                        << **SimplifiedValue << "\n");
5290
5291      Changed = ChangeStatus::CHANGED;
5292    }
5293
5294    return Changed;
5295  }
5296
5297  ChangeStatus indicatePessimisticFixpoint() override {
5298    SimplifiedValue = nullptr;
5299    return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5300  }
5301
5302private:
5303  /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
5304  ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5305    std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5306
5307    unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5308    unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5309    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5310        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5311
5312    if (!CallerKernelInfoAA ||
5313        !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5314      return indicatePessimisticFixpoint();
5315
5316    for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5317      auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5318                                          DepClassTy::REQUIRED);
5319
5320      if (!AA || !AA->isValidState()) {
5321        SimplifiedValue = nullptr;
5322        return indicatePessimisticFixpoint();
5323      }
5324
5325      if (AA->SPMDCompatibilityTracker.isAssumed()) {
5326        if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5327          ++KnownSPMDCount;
5328        else
5329          ++AssumedSPMDCount;
5330      } else {
5331        if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5332          ++KnownNonSPMDCount;
5333        else
5334          ++AssumedNonSPMDCount;
5335      }
5336    }
5337
5338    if ((AssumedSPMDCount + KnownSPMDCount) &&
5339        (AssumedNonSPMDCount + KnownNonSPMDCount))
5340      return indicatePessimisticFixpoint();
5341
5342    auto &Ctx = getAnchorValue().getContext();
5343    if (KnownSPMDCount || AssumedSPMDCount) {
5344      assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5345             "Expected only SPMD kernels!");
5346      // All reaching kernels are in SPMD mode. Update all function calls to
5347      // __kmpc_is_spmd_exec_mode to 1.
5348      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5349    } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5350      assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5351             "Expected only non-SPMD kernels!");
5352      // All reaching kernels are in non-SPMD mode. Update all function
5353      // calls to __kmpc_is_spmd_exec_mode to 0.
5354      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5355    } else {
5356      // We have empty reaching kernels, therefore we cannot tell if the
5357      // associated call site can be folded. At this moment, SimplifiedValue
5358      // must be none.
5359      assert(!SimplifiedValue && "SimplifiedValue should be none");
5360    }
5361
5362    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5363                                                    : ChangeStatus::CHANGED;
5364  }
5365
5366  /// Fold __kmpc_parallel_level into a constant if possible.
5367  ChangeStatus foldParallelLevel(Attributor &A) {
5368    std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5369
5370    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5371        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5372
5373    if (!CallerKernelInfoAA ||
5374        !CallerKernelInfoAA->ParallelLevels.isValidState())
5375      return indicatePessimisticFixpoint();
5376
5377    if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5378      return indicatePessimisticFixpoint();
5379
5380    if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5381      assert(!SimplifiedValue &&
5382             "SimplifiedValue should keep none at this point");
5383      return ChangeStatus::UNCHANGED;
5384    }
5385
5386    unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5387    unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5388    for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5389      auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5390                                          DepClassTy::REQUIRED);
5391      if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5392        return indicatePessimisticFixpoint();
5393
5394      if (AA->SPMDCompatibilityTracker.isAssumed()) {
5395        if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5396          ++KnownSPMDCount;
5397        else
5398          ++AssumedSPMDCount;
5399      } else {
5400        if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5401          ++KnownNonSPMDCount;
5402        else
5403          ++AssumedNonSPMDCount;
5404      }
5405    }
5406
5407    if ((AssumedSPMDCount + KnownSPMDCount) &&
5408        (AssumedNonSPMDCount + KnownNonSPMDCount))
5409      return indicatePessimisticFixpoint();
5410
5411    auto &Ctx = getAnchorValue().getContext();
5412    // If the caller can only be reached by SPMD kernel entries, the parallel
5413    // level is 1. Similarly, if the caller can only be reached by non-SPMD
5414    // kernel entries, it is 0.
5415    if (AssumedSPMDCount || KnownSPMDCount) {
5416      assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5417             "Expected only SPMD kernels!");
5418      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5419    } else {
5420      assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5421             "Expected only non-SPMD kernels!");
5422      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5423    }
5424    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5425                                                    : ChangeStatus::CHANGED;
5426  }
5427
5428  ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5429    // Specialize only if all the calls agree with the attribute constant value
5430    int32_t CurrentAttrValue = -1;
5431    std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5432
5433    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5434        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5435
5436    if (!CallerKernelInfoAA ||
5437        !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5438      return indicatePessimisticFixpoint();
5439
5440    // Iterate over the kernels that reach this function
5441    for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5442      int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5443
5444      if (NextAttrVal == -1 ||
5445          (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5446        return indicatePessimisticFixpoint();
5447      CurrentAttrValue = NextAttrVal;
5448    }
5449
5450    if (CurrentAttrValue != -1) {
5451      auto &Ctx = getAnchorValue().getContext();
5452      SimplifiedValue =
5453          ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5454    }
5455    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5456                                                    : ChangeStatus::CHANGED;
5457  }
5458
5459  /// An optional value the associated value is assumed to fold to. That is, we
5460  /// assume the associated value (which is a call) can be replaced by this
5461  /// simplified value.
5462  std::optional<Value *> SimplifiedValue;
5463
5464  /// The runtime function kind of the callee of the associated call site.
5465  RuntimeFunction RFKind;
5466};
5467
5468} // namespace
5469
5470/// Register folding callsite
5471void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5472  auto &RFI = OMPInfoCache.RFIs[RF];
5473  RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5474    CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5475    if (!CI)
5476      return false;
5477    A.getOrCreateAAFor<AAFoldRuntimeCall>(
5478        IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5479        DepClassTy::NONE, /* ForceUpdate */ false,
5480        /* UpdateAfterInit */ false);
5481    return false;
5482  });
5483}
5484
5485void OpenMPOpt::registerAAs(bool IsModulePass) {
5486  if (SCC.empty())
5487    return;
5488
5489  if (IsModulePass) {
5490    // Ensure we create the AAKernelInfo AAs first and without triggering an
5491    // update. This will make sure we register all value simplification
5492    // callbacks before any other AA has the chance to create an AAValueSimplify
5493    // or similar.
5494    auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5495      A.getOrCreateAAFor<AAKernelInfo>(
5496          IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5497          DepClassTy::NONE, /* ForceUpdate */ false,
5498          /* UpdateAfterInit */ false);
5499      return false;
5500    };
5501    OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5502        OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5503    InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5504
5505    registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5506    registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5507    registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5508    registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5509  }
5510
5511  // Create CallSite AA for all Getters.
5512  if (DeduceICVValues) {
5513    for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5514      auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5515
5516      auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5517
5518      auto CreateAA = [&](Use &U, Function &Caller) {
5519        CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5520        if (!CI)
5521          return false;
5522
5523        auto &CB = cast<CallBase>(*CI);
5524
5525        IRPosition CBPos = IRPosition::callsite_function(CB);
5526        A.getOrCreateAAFor<AAICVTracker>(CBPos);
5527        return false;
5528      };
5529
5530      GetterRFI.foreachUse(SCC, CreateAA);
5531    }
5532  }
5533
5534  // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5535  // every function if there is a device kernel.
5536  if (!isOpenMPDevice(M))
5537    return;
5538
5539  for (auto *F : SCC) {
5540    if (F->isDeclaration())
5541      continue;
5542
5543    // We look at internal functions only on-demand but if any use is not a
5544    // direct call or outside the current set of analyzed functions, we have
5545    // to do it eagerly.
5546    if (F->hasLocalLinkage()) {
5547      if (llvm::all_of(F->uses(), [this](const Use &U) {
5548            const auto *CB = dyn_cast<CallBase>(U.getUser());
5549            return CB && CB->isCallee(&U) &&
5550                   A.isRunOn(const_cast<Function *>(CB->getCaller()));
5551          }))
5552        continue;
5553    }
5554    registerAAsForFunction(A, *F);
5555  }
5556}
5557
5558void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5559  if (!DisableOpenMPOptDeglobalization)
5560    A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5561  A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5562  if (!DisableOpenMPOptDeglobalization)
5563    A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5564  if (F.hasFnAttribute(Attribute::Convergent))
5565    A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5566
5567  for (auto &I : instructions(F)) {
5568    if (auto *LI = dyn_cast<LoadInst>(&I)) {
5569      bool UsedAssumedInformation = false;
5570      A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5571                             UsedAssumedInformation, AA::Interprocedural);
5572      continue;
5573    }
5574    if (auto *CI = dyn_cast<CallBase>(&I)) {
5575      if (CI->isIndirectCall())
5576        A.getOrCreateAAFor<AAIndirectCallInfo>(
5577            IRPosition::callsite_function(*CI));
5578    }
5579    if (auto *SI = dyn_cast<StoreInst>(&I)) {
5580      A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5581      continue;
5582    }
5583    if (auto *FI = dyn_cast<FenceInst>(&I)) {
5584      A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5585      continue;
5586    }
5587    if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5588      if (II->getIntrinsicID() == Intrinsic::assume) {
5589        A.getOrCreateAAFor<AAPotentialValues>(
5590            IRPosition::value(*II->getArgOperand(0)));
5591        continue;
5592      }
5593    }
5594  }
5595}
5596
5597const char AAICVTracker::ID = 0;
5598const char AAKernelInfo::ID = 0;
5599const char AAExecutionDomain::ID = 0;
5600const char AAHeapToShared::ID = 0;
5601const char AAFoldRuntimeCall::ID = 0;
5602
5603AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5604                                              Attributor &A) {
5605  AAICVTracker *AA = nullptr;
5606  switch (IRP.getPositionKind()) {
5607  case IRPosition::IRP_INVALID:
5608  case IRPosition::IRP_FLOAT:
5609  case IRPosition::IRP_ARGUMENT:
5610  case IRPosition::IRP_CALL_SITE_ARGUMENT:
5611    llvm_unreachable("ICVTracker can only be created for function position!");
5612  case IRPosition::IRP_RETURNED:
5613    AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5614    break;
5615  case IRPosition::IRP_CALL_SITE_RETURNED:
5616    AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5617    break;
5618  case IRPosition::IRP_CALL_SITE:
5619    AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5620    break;
5621  case IRPosition::IRP_FUNCTION:
5622    AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5623    break;
5624  }
5625
5626  return *AA;
5627}
5628
5629AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
5630                                                        Attributor &A) {
5631  AAExecutionDomainFunction *AA = nullptr;
5632  switch (IRP.getPositionKind()) {
5633  case IRPosition::IRP_INVALID:
5634  case IRPosition::IRP_FLOAT:
5635  case IRPosition::IRP_ARGUMENT:
5636  case IRPosition::IRP_CALL_SITE_ARGUMENT:
5637  case IRPosition::IRP_RETURNED:
5638  case IRPosition::IRP_CALL_SITE_RETURNED:
5639  case IRPosition::IRP_CALL_SITE:
5640    llvm_unreachable(
5641        "AAExecutionDomain can only be created for function position!");
5642  case IRPosition::IRP_FUNCTION:
5643    AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5644    break;
5645  }
5646
5647  return *AA;
5648}
5649
5650AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5651                                                  Attributor &A) {
5652  AAHeapToSharedFunction *AA = nullptr;
5653  switch (IRP.getPositionKind()) {
5654  case IRPosition::IRP_INVALID:
5655  case IRPosition::IRP_FLOAT:
5656  case IRPosition::IRP_ARGUMENT:
5657  case IRPosition::IRP_CALL_SITE_ARGUMENT:
5658  case IRPosition::IRP_RETURNED:
5659  case IRPosition::IRP_CALL_SITE_RETURNED:
5660  case IRPosition::IRP_CALL_SITE:
5661    llvm_unreachable(
5662        "AAHeapToShared can only be created for function position!");
5663  case IRPosition::IRP_FUNCTION:
5664    AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5665    break;
5666  }
5667
5668  return *AA;
5669}
5670
5671AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5672                                              Attributor &A) {
5673  AAKernelInfo *AA = nullptr;
5674  switch (IRP.getPositionKind()) {
5675  case IRPosition::IRP_INVALID:
5676  case IRPosition::IRP_FLOAT:
5677  case IRPosition::IRP_ARGUMENT:
5678  case IRPosition::IRP_RETURNED:
5679  case IRPosition::IRP_CALL_SITE_RETURNED:
5680  case IRPosition::IRP_CALL_SITE_ARGUMENT:
5681    llvm_unreachable("KernelInfo can only be created for function position!");
5682  case IRPosition::IRP_CALL_SITE:
5683    AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5684    break;
5685  case IRPosition::IRP_FUNCTION:
5686    AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5687    break;
5688  }
5689
5690  return *AA;
5691}
5692
5693AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5694                                                        Attributor &A) {
5695  AAFoldRuntimeCall *AA = nullptr;
5696  switch (IRP.getPositionKind()) {
5697  case IRPosition::IRP_INVALID:
5698  case IRPosition::IRP_FLOAT:
5699  case IRPosition::IRP_ARGUMENT:
5700  case IRPosition::IRP_RETURNED:
5701  case IRPosition::IRP_FUNCTION:
5702  case IRPosition::IRP_CALL_SITE:
5703  case IRPosition::IRP_CALL_SITE_ARGUMENT:
5704    llvm_unreachable("KernelInfo can only be created for call site position!");
5705  case IRPosition::IRP_CALL_SITE_RETURNED:
5706    AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5707    break;
5708  }
5709
5710  return *AA;
5711}
5712
5713PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
5714  if (!containsOpenMP(M))
5715    return PreservedAnalyses::all();
5716  if (DisableOpenMPOptimizations)
5717    return PreservedAnalyses::all();
5718
5719  FunctionAnalysisManager &FAM =
5720      AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
5721  KernelSet Kernels = getDeviceKernels(M);
5722
5723  if (PrintModuleBeforeOptimizations)
5724    LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5725
5726  auto IsCalled = [&](Function &F) {
5727    if (Kernels.contains(&F))
5728      return true;
5729    for (const User *U : F.users())
5730      if (!isa<BlockAddress>(U))
5731        return true;
5732    return false;
5733  };
5734
5735  auto EmitRemark = [&](Function &F) {
5736    auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
5737    ORE.emit([&]() {
5738      OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5739      return ORA << "Could not internalize function. "
5740                 << "Some optimizations may not be possible. [OMP140]";
5741    });
5742  };
5743
5744  bool Changed = false;
5745
5746  // Create internal copies of each function if this is a kernel Module. This
5747  // allows iterprocedural passes to see every call edge.
5748  DenseMap<Function *, Function *> InternalizedMap;
5749  if (isOpenMPDevice(M)) {
5750    SmallPtrSet<Function *, 16> InternalizeFns;
5751    for (Function &F : M)
5752      if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5753          !DisableInternalization) {
5754        if (Attributor::isInternalizable(F)) {
5755          InternalizeFns.insert(&F);
5756        } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5757          EmitRemark(F);
5758        }
5759      }
5760
5761    Changed |=
5762        Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5763  }
5764
5765  // Look at every function in the Module unless it was internalized.
5766  SetVector<Function *> Functions;
5767  SmallVector<Function *, 16> SCC;
5768  for (Function &F : M)
5769    if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5770      SCC.push_back(&F);
5771      Functions.insert(&F);
5772    }
5773
5774  if (SCC.empty())
5775    return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
5776
5777  AnalysisGetter AG(FAM);
5778
5779  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5780    return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5781  };
5782
5783  BumpPtrAllocator Allocator;
5784  CallGraphUpdater CGUpdater;
5785
5786  bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5787                  LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
5788  OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5789
5790  unsigned MaxFixpointIterations =
5791      (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5792
5793  AttributorConfig AC(CGUpdater);
5794  AC.DefaultInitializeLiveInternals = false;
5795  AC.IsModulePass = true;
5796  AC.RewriteSignatures = false;
5797  AC.MaxFixpointIterations = MaxFixpointIterations;
5798  AC.OREGetter = OREGetter;
5799  AC.PassName = DEBUG_TYPE;
5800  AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5801  AC.IPOAmendableCB = [](const Function &F) {
5802    return F.hasFnAttribute("kernel");
5803  };
5804
5805  Attributor A(Functions, InfoCache, AC);
5806
5807  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5808  Changed |= OMPOpt.run(true);
5809
5810  // Optionally inline device functions for potentially better performance.
5811  if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5812    for (Function &F : M)
5813      if (!F.isDeclaration() && !Kernels.contains(&F) &&
5814          !F.hasFnAttribute(Attribute::NoInline))
5815        F.addFnAttr(Attribute::AlwaysInline);
5816
5817  if (PrintModuleAfterOptimizations)
5818    LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5819
5820  if (Changed)
5821    return PreservedAnalyses::none();
5822
5823  return PreservedAnalyses::all();
5824}
5825
5826PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5827                                          CGSCCAnalysisManager &AM,
5828                                          LazyCallGraph &CG,
5829                                          CGSCCUpdateResult &UR) {
5830  if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5831    return PreservedAnalyses::all();
5832  if (DisableOpenMPOptimizations)
5833    return PreservedAnalyses::all();
5834
5835  SmallVector<Function *, 16> SCC;
5836  // If there are kernels in the module, we have to run on all SCC's.
5837  for (LazyCallGraph::Node &N : C) {
5838    Function *Fn = &N.getFunction();
5839    SCC.push_back(Fn);
5840  }
5841
5842  if (SCC.empty())
5843    return PreservedAnalyses::all();
5844
5845  Module &M = *C.begin()->getFunction().getParent();
5846
5847  if (PrintModuleBeforeOptimizations)
5848    LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5849
5850  KernelSet Kernels = getDeviceKernels(M);
5851
5852  FunctionAnalysisManager &FAM =
5853      AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5854
5855  AnalysisGetter AG(FAM);
5856
5857  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5858    return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5859  };
5860
5861  BumpPtrAllocator Allocator;
5862  CallGraphUpdater CGUpdater;
5863  CGUpdater.initialize(CG, C, AM, UR);
5864
5865  bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5866                  LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
5867  SetVector<Function *> Functions(SCC.begin(), SCC.end());
5868  OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5869                                /*CGSCC*/ &Functions, PostLink);
5870
5871  unsigned MaxFixpointIterations =
5872      (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5873
5874  AttributorConfig AC(CGUpdater);
5875  AC.DefaultInitializeLiveInternals = false;
5876  AC.IsModulePass = false;
5877  AC.RewriteSignatures = false;
5878  AC.MaxFixpointIterations = MaxFixpointIterations;
5879  AC.OREGetter = OREGetter;
5880  AC.PassName = DEBUG_TYPE;
5881  AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5882
5883  Attributor A(Functions, InfoCache, AC);
5884
5885  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5886  bool Changed = OMPOpt.run(false);
5887
5888  if (PrintModuleAfterOptimizations)
5889    LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5890
5891  if (Changed)
5892    return PreservedAnalyses::none();
5893
5894  return PreservedAnalyses::all();
5895}
5896
5897bool llvm::omp::isOpenMPKernel(Function &Fn) {
5898  return Fn.hasFnAttribute("kernel");
5899}
5900
5901KernelSet llvm::omp::getDeviceKernels(Module &M) {
5902  // TODO: Create a more cross-platform way of determining device kernels.
5903  NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
5904  KernelSet Kernels;
5905
5906  if (!MD)
5907    return Kernels;
5908
5909  for (auto *Op : MD->operands()) {
5910    if (Op->getNumOperands() < 2)
5911      continue;
5912    MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5913    if (!KindID || KindID->getString() != "kernel")
5914      continue;
5915
5916    Function *KernelFn =
5917        mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
5918    if (!KernelFn)
5919      continue;
5920
5921    // We are only interested in OpenMP target regions. Others, such as kernels
5922    // generated by CUDA but linked together, are not interesting to this pass.
5923    if (isOpenMPKernel(*KernelFn)) {
5924      ++NumOpenMPTargetRegionKernels;
5925      Kernels.insert(KernelFn);
5926    } else
5927      ++NumNonOpenMPTargetRegionKernels;
5928  }
5929
5930  return Kernels;
5931}
5932
5933bool llvm::omp::containsOpenMP(Module &M) {
5934  Metadata *MD = M.getModuleFlag("openmp");
5935  if (!MD)
5936    return false;
5937
5938  return true;
5939}
5940
5941bool llvm::omp::isOpenMPDevice(Module &M) {
5942  Metadata *MD = M.getModuleFlag("openmp-device");
5943  if (!MD)
5944    return false;
5945
5946  return true;
5947}
5948