1//===- ConvergenceVerifier.cpp - Verify convergence control -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/IR/ConvergenceVerifier.h"
10#include "llvm/IR/Dominators.h"
11#include "llvm/IR/GenericConvergenceVerifierImpl.h"
12#include "llvm/IR/Instructions.h"
13#include "llvm/IR/SSAContext.h"
14
15using namespace llvm;
16
17template <>
18const Instruction *
19GenericConvergenceVerifier<SSAContext>::findAndCheckConvergenceTokenUsed(
20    const Instruction &I) {
21  auto *CB = dyn_cast<CallBase>(&I);
22  if (!CB)
23    return nullptr;
24
25  unsigned Count =
26      CB->countOperandBundlesOfType(LLVMContext::OB_convergencectrl);
27  CheckOrNull(Count <= 1,
28              "The 'convergencectrl' bundle can occur at most once on a call",
29              {Context.print(CB)});
30  if (!Count)
31    return nullptr;
32
33  auto Bundle = CB->getOperandBundle(LLVMContext::OB_convergencectrl);
34  CheckOrNull(Bundle->Inputs.size() == 1 &&
35                  Bundle->Inputs[0]->getType()->isTokenTy(),
36              "The 'convergencectrl' bundle requires exactly one token use.",
37              {Context.print(CB)});
38  auto *Token = Bundle->Inputs[0].get();
39  auto *Def = dyn_cast<Instruction>(Token);
40
41  CheckOrNull(
42      Def && isConvergenceControlIntrinsic(SSAContext::getIntrinsicID(*Def)),
43      "Convergence control tokens can only be produced by calls to the "
44      "convergence control intrinsics.",
45      {Context.print(Token), Context.print(&I)});
46
47  if (Def)
48    Tokens[&I] = Def;
49
50  return Def;
51}
52
53template <>
54bool GenericConvergenceVerifier<SSAContext>::isInsideConvergentFunction(
55    const InstructionT &I) {
56  auto *F = I.getFunction();
57  return F->isConvergent();
58}
59
60template <>
61bool GenericConvergenceVerifier<SSAContext>::isConvergent(
62    const InstructionT &I) {
63  if (auto *CB = dyn_cast<CallBase>(&I)) {
64    return CB->isConvergent();
65  }
66  return false;
67}
68
// Explicitly instantiate the verifier for SSAContext. This follows the
// explicit member specializations earlier in this file, so they are declared
// before the instantiation.
template class llvm::GenericConvergenceVerifier<SSAContext>;
70