1//===-- AbstractCallSite.cpp - Implementation of abstract call sites ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements abstract call sites which unify the interface for
10// direct, indirect, and callback call sites.
11//
12// For more information see:
13// https://llvm.org/devmtg/2018-10/talk-abstracts.html#talk20
14//
15//===----------------------------------------------------------------------===//
16
17#include "llvm/ADT/Statistic.h"
18#include "llvm/ADT/StringSwitch.h"
19#include "llvm/IR/CallSite.h"
20#include "llvm/Support/Debug.h"
21
22using namespace llvm;
23
24#define DEBUG_TYPE "abstract-call-sites"
25
// Counters recording the outcome of each AbstractCallSite(const Use *)
// construction below. Every path through that constructor bumps one of them.
// getCallbackUses() does not touch these counters.

// The use is a callback callee argument described by !callback metadata.
STATISTIC(NumCallbackCallSites, "Number of callback call sites created");
// The use is the callee operand of a (direct or indirect) call site.
STATISTIC(NumDirectAbstractCallSites,
          "Number of direct abstract call sites created");
// The user is not a call site, even after looking through a single-use
// constant cast expression.
STATISTIC(NumInvalidAbstractCallSitesUnknownUse,
          "Number of invalid abstract call sites created (unknown use)");
// The call site's called operand is not a known Function (broker unknown).
STATISTIC(NumInvalidAbstractCallSitesUnknownCallee,
          "Number of invalid abstract call sites created (unknown callee)");
// The broker has no !callback metadata, or none that names this argument.
STATISTIC(NumInvalidAbstractCallSitesNoCallback,
          "Number of invalid abstract call sites created (no callback)");
35
36void AbstractCallSite::getCallbackUses(ImmutableCallSite ICS,
37                                       SmallVectorImpl<const Use *> &CBUses) {
38  const Function *Callee = ICS.getCalledFunction();
39  if (!Callee)
40    return;
41
42  MDNode *CallbackMD = Callee->getMetadata(LLVMContext::MD_callback);
43  if (!CallbackMD)
44    return;
45
46  for (const MDOperand &Op : CallbackMD->operands()) {
47    MDNode *OpMD = cast<MDNode>(Op.get());
48    auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(OpMD->getOperand(0));
49    uint64_t CBCalleeIdx =
50        cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue();
51    CBUses.push_back(ICS.arg_begin() + CBCalleeIdx);
52  }
53}
54
55/// Create an abstract call site from a use.
56AbstractCallSite::AbstractCallSite(const Use *U) : CS(U->getUser()) {
57
58  // First handle unknown users.
59  if (!CS) {
60
61    // If the use is actually in a constant cast expression which itself
62    // has only one use, we look through the constant cast expression.
63    // This happens by updating the use @p U to the use of the constant
64    // cast expression and afterwards re-initializing CS accordingly.
65    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U->getUser()))
66      if (CE->getNumUses() == 1 && CE->isCast()) {
67        U = &*CE->use_begin();
68        CS = CallSite(U->getUser());
69      }
70
71    if (!CS) {
72      NumInvalidAbstractCallSitesUnknownUse++;
73      return;
74    }
75  }
76
77  // Then handle direct or indirect calls. Thus, if U is the callee of the
78  // call site CS it is not a callback and we are done.
79  if (CS.isCallee(U)) {
80    NumDirectAbstractCallSites++;
81    return;
82  }
83
84  // If we cannot identify the broker function we cannot create a callback and
85  // invalidate the abstract call site.
86  Function *Callee = CS.getCalledFunction();
87  if (!Callee) {
88    NumInvalidAbstractCallSitesUnknownCallee++;
89    CS = CallSite();
90    return;
91  }
92
93  MDNode *CallbackMD = Callee->getMetadata(LLVMContext::MD_callback);
94  if (!CallbackMD) {
95    NumInvalidAbstractCallSitesNoCallback++;
96    CS = CallSite();
97    return;
98  }
99
100  unsigned UseIdx = CS.getArgumentNo(U);
101  MDNode *CallbackEncMD = nullptr;
102  for (const MDOperand &Op : CallbackMD->operands()) {
103    MDNode *OpMD = cast<MDNode>(Op.get());
104    auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(OpMD->getOperand(0));
105    uint64_t CBCalleeIdx =
106        cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue();
107    if (CBCalleeIdx != UseIdx)
108      continue;
109    CallbackEncMD = OpMD;
110    break;
111  }
112
113  if (!CallbackEncMD) {
114    NumInvalidAbstractCallSitesNoCallback++;
115    CS = CallSite();
116    return;
117  }
118
119  NumCallbackCallSites++;
120
121  assert(CallbackEncMD->getNumOperands() >= 2 && "Incomplete !callback metadata");
122
123  unsigned NumCallOperands = CS.getNumArgOperands();
124  // Skip the var-arg flag at the end when reading the metadata.
125  for (unsigned u = 0, e = CallbackEncMD->getNumOperands() - 1; u < e; u++) {
126    Metadata *OpAsM = CallbackEncMD->getOperand(u).get();
127    auto *OpAsCM = cast<ConstantAsMetadata>(OpAsM);
128    assert(OpAsCM->getType()->isIntegerTy(64) &&
129           "Malformed !callback metadata");
130
131    int64_t Idx = cast<ConstantInt>(OpAsCM->getValue())->getSExtValue();
132    assert(-1 <= Idx && Idx <= NumCallOperands &&
133           "Out-of-bounds !callback metadata index");
134
135    CI.ParameterEncoding.push_back(Idx);
136  }
137
138  if (!Callee->isVarArg())
139    return;
140
141  Metadata *VarArgFlagAsM =
142      CallbackEncMD->getOperand(CallbackEncMD->getNumOperands() - 1).get();
143  auto *VarArgFlagAsCM = cast<ConstantAsMetadata>(VarArgFlagAsM);
144  assert(VarArgFlagAsCM->getType()->isIntegerTy(1) &&
145         "Malformed !callback metadata var-arg flag");
146
147  if (VarArgFlagAsCM->getValue()->isNullValue())
148    return;
149
150  // Add all variadic arguments at the end.
151  for (unsigned u = Callee->arg_size(); u < NumCallOperands; u++)
152    CI.ParameterEncoding.push_back(u);
153}
154