1343171Sdim//===- NVPTXProxyRegErasure.cpp - NVPTX Proxy Register Instruction Erasure -==//
2343171Sdim//
3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4353358Sdim// See https://llvm.org/LICENSE.txt for license information.
5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6343171Sdim//
7343171Sdim//===----------------------------------------------------------------------===//
8343171Sdim//
// The pass is needed to remove ProxyReg instructions and restore related
// registers. The instructions were needed at the instruction selection stage
// to make sure that callseq_end nodes won't be removed as "dead nodes". This
// can happen when we expand instructions into libcalls and the call site
// doesn't care about the libcall chain. The call site cares only about data
// flow, and the last data flow node happens to come before callseq_end.
// Therefore the node becomes dangling and "dead". The ProxyReg acts like an
// additional data flow node *after* the callseq_end in the chain and ensures
// that everything will be preserved.
18343171Sdim//
19343171Sdim//===----------------------------------------------------------------------===//
20343171Sdim
21343171Sdim#include "NVPTX.h"
22343171Sdim#include "llvm/CodeGen/MachineFunctionPass.h"
23343171Sdim#include "llvm/CodeGen/MachineInstrBuilder.h"
24343171Sdim#include "llvm/CodeGen/MachineRegisterInfo.h"
25343171Sdim#include "llvm/CodeGen/TargetInstrInfo.h"
26343171Sdim#include "llvm/CodeGen/TargetRegisterInfo.h"
27343171Sdim
28343171Sdimusing namespace llvm;
29343171Sdim
30343171Sdimnamespace llvm {
31343171Sdimvoid initializeNVPTXProxyRegErasurePass(PassRegistry &);
32343171Sdim}
33343171Sdim
34343171Sdimnamespace {
35343171Sdim
36343171Sdimstruct NVPTXProxyRegErasure : public MachineFunctionPass {
37343171Sdimpublic:
38343171Sdim  static char ID;
39343171Sdim  NVPTXProxyRegErasure() : MachineFunctionPass(ID) {
40343171Sdim    initializeNVPTXProxyRegErasurePass(*PassRegistry::getPassRegistry());
41343171Sdim  }
42343171Sdim
43343171Sdim  bool runOnMachineFunction(MachineFunction &MF) override;
44343171Sdim
45343171Sdim  StringRef getPassName() const override {
46343171Sdim    return "NVPTX Proxy Register Instruction Erasure";
47343171Sdim  }
48343171Sdim
49343171Sdim  void getAnalysisUsage(AnalysisUsage &AU) const override {
50343171Sdim    MachineFunctionPass::getAnalysisUsage(AU);
51343171Sdim  }
52343171Sdim
53343171Sdimprivate:
54343171Sdim  void replaceMachineInstructionUsage(MachineFunction &MF, MachineInstr &MI);
55343171Sdim
56343171Sdim  void replaceRegisterUsage(MachineInstr &Instr, MachineOperand &From,
57343171Sdim                            MachineOperand &To);
58343171Sdim};
59343171Sdim
60343171Sdim} // namespace
61343171Sdim
62343171Sdimchar NVPTXProxyRegErasure::ID = 0;
63343171Sdim
64343171SdimINITIALIZE_PASS(NVPTXProxyRegErasure, "nvptx-proxyreg-erasure", "NVPTX ProxyReg Erasure", false, false)
65343171Sdim
66343171Sdimbool NVPTXProxyRegErasure::runOnMachineFunction(MachineFunction &MF) {
67343171Sdim  SmallVector<MachineInstr *, 16> RemoveList;
68343171Sdim
69343171Sdim  for (auto &BB : MF) {
70343171Sdim    for (auto &MI : BB) {
71343171Sdim      switch (MI.getOpcode()) {
72343171Sdim      case NVPTX::ProxyRegI1:
73343171Sdim      case NVPTX::ProxyRegI16:
74343171Sdim      case NVPTX::ProxyRegI32:
75343171Sdim      case NVPTX::ProxyRegI64:
76343171Sdim      case NVPTX::ProxyRegF16:
77343171Sdim      case NVPTX::ProxyRegF16x2:
78343171Sdim      case NVPTX::ProxyRegF32:
79343171Sdim      case NVPTX::ProxyRegF64:
80343171Sdim        replaceMachineInstructionUsage(MF, MI);
81343171Sdim        RemoveList.push_back(&MI);
82343171Sdim        break;
83343171Sdim      }
84343171Sdim    }
85343171Sdim  }
86343171Sdim
87343171Sdim  for (auto *MI : RemoveList) {
88343171Sdim    MI->eraseFromParent();
89343171Sdim  }
90343171Sdim
91343171Sdim  return !RemoveList.empty();
92343171Sdim}
93343171Sdim
94343171Sdimvoid NVPTXProxyRegErasure::replaceMachineInstructionUsage(MachineFunction &MF,
95343171Sdim                                                          MachineInstr &MI) {
96343171Sdim  auto &InOp = *MI.uses().begin();
97343171Sdim  auto &OutOp = *MI.defs().begin();
98343171Sdim
99343171Sdim  assert(InOp.isReg() && "ProxyReg input operand should be a register.");
100343171Sdim  assert(OutOp.isReg() && "ProxyReg output operand should be a register.");
101343171Sdim
102343171Sdim  for (auto &BB : MF) {
103343171Sdim    for (auto &I : BB) {
104343171Sdim      replaceRegisterUsage(I, OutOp, InOp);
105343171Sdim    }
106343171Sdim  }
107343171Sdim}
108343171Sdim
109343171Sdimvoid NVPTXProxyRegErasure::replaceRegisterUsage(MachineInstr &Instr,
110343171Sdim                                                MachineOperand &From,
111343171Sdim                                                MachineOperand &To) {
112343171Sdim  for (auto &Op : Instr.uses()) {
113343171Sdim    if (Op.isReg() && Op.getReg() == From.getReg()) {
114343171Sdim      Op.setReg(To.getReg());
115343171Sdim    }
116343171Sdim  }
117343171Sdim}
118343171Sdim
119343171SdimMachineFunctionPass *llvm::createNVPTXProxyRegErasurePass() {
120343171Sdim  return new NVPTXProxyRegErasure();
121343171Sdim}
122