//=--- RegUsageInfoPropagate.cpp - Register Usage Information Propagation --=//
2303231Sdim//
3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4353358Sdim// See https://llvm.org/LICENSE.txt for license information.
5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6303231Sdim//
7303231Sdim//===----------------------------------------------------------------------===//
8303231Sdim///
9303231Sdim/// This pass is required to take advantage of the interprocedural register
10303231Sdim/// allocation infrastructure.
11303231Sdim///
12303231Sdim/// This pass iterates through MachineInstrs in a given MachineFunction and at
13303231Sdim/// each callsite queries RegisterUsageInfo for RegMask (calculated based on
14303231Sdim/// actual register allocation) of the callee function, if the RegMask detail
15303231Sdim/// is available then this pass will update the RegMask of the call instruction.
16303231Sdim/// This updated RegMask will be used by the register allocator while allocating
17303231Sdim/// the current MachineFunction.
18303231Sdim///
19303231Sdim//===----------------------------------------------------------------------===//
20303231Sdim
21303231Sdim#include "llvm/CodeGen/MachineBasicBlock.h"
22303231Sdim#include "llvm/CodeGen/MachineFunctionPass.h"
23327952Sdim#include "llvm/CodeGen/MachineFrameInfo.h"
24303231Sdim#include "llvm/CodeGen/MachineInstr.h"
25303231Sdim#include "llvm/CodeGen/MachineRegisterInfo.h"
26303231Sdim#include "llvm/CodeGen/Passes.h"
27303231Sdim#include "llvm/CodeGen/RegisterUsageInfo.h"
28303231Sdim#include "llvm/IR/Module.h"
29303231Sdim#include "llvm/PassAnalysisSupport.h"
30303231Sdim#include "llvm/Support/Debug.h"
31303231Sdim#include "llvm/Support/raw_ostream.h"
32303231Sdim#include "llvm/Target/TargetMachine.h"
33303231Sdim#include <map>
34303231Sdim#include <string>
35303231Sdim
36303231Sdimusing namespace llvm;
37303231Sdim
38303231Sdim#define DEBUG_TYPE "ip-regalloc"
39303231Sdim
40303231Sdim#define RUIP_NAME "Register Usage Information Propagation"
41303231Sdim
42303231Sdimnamespace {
43303231Sdim
44341825Sdimclass RegUsageInfoPropagation : public MachineFunctionPass {
45303231Sdimpublic:
46341825Sdim  RegUsageInfoPropagation() : MachineFunctionPass(ID) {
47303231Sdim    PassRegistry &Registry = *PassRegistry::getPassRegistry();
48341825Sdim    initializeRegUsageInfoPropagationPass(Registry);
49303231Sdim  }
50303231Sdim
51314564Sdim  StringRef getPassName() const override { return RUIP_NAME; }
52303231Sdim
53303231Sdim  bool runOnMachineFunction(MachineFunction &MF) override;
54303231Sdim
55341825Sdim  void getAnalysisUsage(AnalysisUsage &AU) const override {
56341825Sdim    AU.addRequired<PhysicalRegisterUsageInfo>();
57341825Sdim    AU.setPreservesAll();
58341825Sdim    MachineFunctionPass::getAnalysisUsage(AU);
59341825Sdim  }
60303231Sdim
61303231Sdim  static char ID;
62303231Sdim
63303231Sdimprivate:
64341825Sdim  static void setRegMask(MachineInstr &MI, ArrayRef<uint32_t> RegMask) {
65341825Sdim    assert(RegMask.size() ==
66341825Sdim           MachineOperand::getRegMaskSize(MI.getParent()->getParent()
67341825Sdim                                          ->getRegInfo().getTargetRegisterInfo()
68341825Sdim                                          ->getNumRegs())
69341825Sdim           && "expected register mask size");
70303231Sdim    for (MachineOperand &MO : MI.operands()) {
71303231Sdim      if (MO.isRegMask())
72341825Sdim        MO.setRegMask(RegMask.data());
73303231Sdim    }
74303231Sdim  }
75303231Sdim};
76341825Sdim
77303231Sdim} // end of anonymous namespace
78303231Sdim
79341825SdimINITIALIZE_PASS_BEGIN(RegUsageInfoPropagation, "reg-usage-propagation",
80303231Sdim                      RUIP_NAME, false, false)
81303231SdimINITIALIZE_PASS_DEPENDENCY(PhysicalRegisterUsageInfo)
82341825SdimINITIALIZE_PASS_END(RegUsageInfoPropagation, "reg-usage-propagation",
83303231Sdim                    RUIP_NAME, false, false)
84303231Sdim
85341825Sdimchar RegUsageInfoPropagation::ID = 0;
86303231Sdim
87327952Sdim// Assumes call instructions have a single reference to a function.
88341825Sdimstatic const Function *findCalledFunction(const Module &M,
89341825Sdim                                          const MachineInstr &MI) {
90341825Sdim  for (const MachineOperand &MO : MI.operands()) {
91327952Sdim    if (MO.isGlobal())
92341825Sdim      return dyn_cast<const Function>(MO.getGlobal());
93327952Sdim
94327952Sdim    if (MO.isSymbol())
95327952Sdim      return M.getFunction(MO.getSymbolName());
96327952Sdim  }
97327952Sdim
98327952Sdim  return nullptr;
99327952Sdim}
100327952Sdim
101341825Sdimbool RegUsageInfoPropagation::runOnMachineFunction(MachineFunction &MF) {
102341825Sdim  const Module &M = *MF.getFunction().getParent();
103303231Sdim  PhysicalRegisterUsageInfo *PRUI = &getAnalysis<PhysicalRegisterUsageInfo>();
104303231Sdim
105341825Sdim  LLVM_DEBUG(dbgs() << " ++++++++++++++++++++ " << getPassName()
106341825Sdim                    << " ++++++++++++++++++++  \n");
107341825Sdim  LLVM_DEBUG(dbgs() << "MachineFunction : " << MF.getName() << "\n");
108303231Sdim
109327952Sdim  const MachineFrameInfo &MFI = MF.getFrameInfo();
110327952Sdim  if (!MFI.hasCalls() && !MFI.hasTailCall())
111327952Sdim    return false;
112327952Sdim
113303231Sdim  bool Changed = false;
114303231Sdim
115303231Sdim  for (MachineBasicBlock &MBB : MF) {
116303231Sdim    for (MachineInstr &MI : MBB) {
117303231Sdim      if (!MI.isCall())
118303231Sdim        continue;
119341825Sdim      LLVM_DEBUG(
120341825Sdim          dbgs()
121341825Sdim          << "Call Instruction Before Register Usage Info Propagation : \n");
122341825Sdim      LLVM_DEBUG(dbgs() << MI << "\n");
123303231Sdim
124341825Sdim      auto UpdateRegMask = [&](const Function &F) {
125341825Sdim        const ArrayRef<uint32_t> RegMask = PRUI->getRegUsageInfo(F);
126341825Sdim        if (RegMask.empty())
127303231Sdim          return;
128341825Sdim        setRegMask(MI, RegMask);
129303231Sdim        Changed = true;
130303231Sdim      };
131303231Sdim
132341825Sdim      if (const Function *F = findCalledFunction(M, MI)) {
133360784Sdim        if (F->isDefinitionExact()) {
134360784Sdim          UpdateRegMask(*F);
135360784Sdim        } else {
136360784Sdim          LLVM_DEBUG(dbgs() << "Function definition is not exact\n");
137360784Sdim        }
138327952Sdim      } else {
139341825Sdim        LLVM_DEBUG(dbgs() << "Failed to find call target function\n");
140327952Sdim      }
141303231Sdim
142341825Sdim      LLVM_DEBUG(
143341825Sdim          dbgs() << "Call Instruction After Register Usage Info Propagation : "
144341825Sdim                 << MI << '\n');
145303231Sdim    }
146303231Sdim  }
147303231Sdim
148341825Sdim  LLVM_DEBUG(
149341825Sdim      dbgs() << " +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
150341825Sdim                "++++++ \n");
151303231Sdim  return Changed;
152303231Sdim}
153341825Sdim
154341825SdimFunctionPass *llvm::createRegUsageInfoPropPass() {
155341825Sdim  return new RegUsageInfoPropagation();
156341825Sdim}
157