//=--- RegUsageInfoPropagate.cpp - Register Usage Information Propagation ---=//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// This pass is required to take advantage of the interprocedural register
10/// allocation infrastructure.
11///
/// This pass iterates through the MachineInstrs of a given MachineFunction and,
/// at each call site, queries PhysicalRegisterUsageInfo for the RegMask of the
/// callee (calculated from the callee's actual register allocation). If a
/// RegMask is available and the callee's definition is exact, this pass
/// replaces the RegMask operand of the call instruction with it. The updated
/// RegMask is then used by the register allocator while allocating the current
/// MachineFunction.
18///
19//===----------------------------------------------------------------------===//
20
21#include "llvm/CodeGen/MachineBasicBlock.h"
22#include "llvm/CodeGen/MachineFunctionPass.h"
23#include "llvm/CodeGen/MachineFrameInfo.h"
24#include "llvm/CodeGen/MachineInstr.h"
25#include "llvm/CodeGen/MachineRegisterInfo.h"
26#include "llvm/CodeGen/Passes.h"
27#include "llvm/CodeGen/RegisterUsageInfo.h"
28#include "llvm/IR/Module.h"
29#include "llvm/PassAnalysisSupport.h"
30#include "llvm/Support/Debug.h"
31#include "llvm/Support/raw_ostream.h"
32#include "llvm/Target/TargetMachine.h"
33#include <map>
34#include <string>
35
36using namespace llvm;
37
// Debug category for the LLVM_DEBUG output emitted by this file.
#define DEBUG_TYPE "ip-regalloc"

// Human-readable pass name, shared by getPassName() and the pass registration
// macros below.
#define RUIP_NAME "Register Usage Information Propagation"
41
42namespace {
43
44class RegUsageInfoPropagation : public MachineFunctionPass {
45public:
46  RegUsageInfoPropagation() : MachineFunctionPass(ID) {
47    PassRegistry &Registry = *PassRegistry::getPassRegistry();
48    initializeRegUsageInfoPropagationPass(Registry);
49  }
50
51  StringRef getPassName() const override { return RUIP_NAME; }
52
53  bool runOnMachineFunction(MachineFunction &MF) override;
54
55  void getAnalysisUsage(AnalysisUsage &AU) const override {
56    AU.addRequired<PhysicalRegisterUsageInfo>();
57    AU.setPreservesAll();
58    MachineFunctionPass::getAnalysisUsage(AU);
59  }
60
61  static char ID;
62
63private:
64  static void setRegMask(MachineInstr &MI, ArrayRef<uint32_t> RegMask) {
65    assert(RegMask.size() ==
66           MachineOperand::getRegMaskSize(MI.getParent()->getParent()
67                                          ->getRegInfo().getTargetRegisterInfo()
68                                          ->getNumRegs())
69           && "expected register mask size");
70    for (MachineOperand &MO : MI.operands()) {
71      if (MO.isRegMask())
72        MO.setRegMask(RegMask.data());
73    }
74  }
75};
76
77} // end of anonymous namespace
78
// Register the pass under the command-line name "reg-usage-propagation" with
// the display name RUIP_NAME, declaring PhysicalRegisterUsageInfo as a
// dependency (the same analysis requested in getAnalysisUsage).
INITIALIZE_PASS_BEGIN(RegUsageInfoPropagation, "reg-usage-propagation",
                      RUIP_NAME, false, false)
INITIALIZE_PASS_DEPENDENCY(PhysicalRegisterUsageInfo)
INITIALIZE_PASS_END(RegUsageInfoPropagation, "reg-usage-propagation",
                    RUIP_NAME, false, false)

// Out-of-line definition of the static pass ID that the constructor passes to
// MachineFunctionPass.
char RegUsageInfoPropagation::ID = 0;
86
87// Assumes call instructions have a single reference to a function.
88static const Function *findCalledFunction(const Module &M,
89                                          const MachineInstr &MI) {
90  for (const MachineOperand &MO : MI.operands()) {
91    if (MO.isGlobal())
92      return dyn_cast<const Function>(MO.getGlobal());
93
94    if (MO.isSymbol())
95      return M.getFunction(MO.getSymbolName());
96  }
97
98  return nullptr;
99}
100
101bool RegUsageInfoPropagation::runOnMachineFunction(MachineFunction &MF) {
102  const Module &M = *MF.getFunction().getParent();
103  PhysicalRegisterUsageInfo *PRUI = &getAnalysis<PhysicalRegisterUsageInfo>();
104
105  LLVM_DEBUG(dbgs() << " ++++++++++++++++++++ " << getPassName()
106                    << " ++++++++++++++++++++  \n");
107  LLVM_DEBUG(dbgs() << "MachineFunction : " << MF.getName() << "\n");
108
109  const MachineFrameInfo &MFI = MF.getFrameInfo();
110  if (!MFI.hasCalls() && !MFI.hasTailCall())
111    return false;
112
113  bool Changed = false;
114
115  for (MachineBasicBlock &MBB : MF) {
116    for (MachineInstr &MI : MBB) {
117      if (!MI.isCall())
118        continue;
119      LLVM_DEBUG(
120          dbgs()
121          << "Call Instruction Before Register Usage Info Propagation : \n");
122      LLVM_DEBUG(dbgs() << MI << "\n");
123
124      auto UpdateRegMask = [&](const Function &F) {
125        const ArrayRef<uint32_t> RegMask = PRUI->getRegUsageInfo(F);
126        if (RegMask.empty())
127          return;
128        setRegMask(MI, RegMask);
129        Changed = true;
130      };
131
132      if (const Function *F = findCalledFunction(M, MI)) {
133        if (F->isDefinitionExact()) {
134          UpdateRegMask(*F);
135        } else {
136          LLVM_DEBUG(dbgs() << "Function definition is not exact\n");
137        }
138      } else {
139        LLVM_DEBUG(dbgs() << "Failed to find call target function\n");
140      }
141
142      LLVM_DEBUG(
143          dbgs() << "Call Instruction After Register Usage Info Propagation : "
144                 << MI << '\n');
145    }
146  }
147
148  LLVM_DEBUG(
149      dbgs() << " +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
150                "++++++ \n");
151  return Changed;
152}
153
154FunctionPass *llvm::createRegUsageInfoPropPass() {
155  return new RegUsageInfoPropagation();
156}
157