MLxExpansionPass.cpp revision 360784
1//===-- MLxExpansionPass.cpp - Expand MLx instrs to avoid hazards ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Expand VFP / NEON floating point MLA / MLS instructions (each to a pair of
// multiply and add / sub instructions) when special VMLx hazards are detected.
11//
12//===----------------------------------------------------------------------===//
13
14#include "ARM.h"
15#include "ARMBaseInstrInfo.h"
16#include "ARMSubtarget.h"
17#include "llvm/ADT/SmallPtrSet.h"
18#include "llvm/ADT/Statistic.h"
19#include "llvm/CodeGen/MachineFunctionPass.h"
20#include "llvm/CodeGen/MachineInstr.h"
21#include "llvm/CodeGen/MachineInstrBuilder.h"
22#include "llvm/CodeGen/MachineRegisterInfo.h"
23#include "llvm/CodeGen/TargetRegisterInfo.h"
24#include "llvm/Support/CommandLine.h"
25#include "llvm/Support/Debug.h"
26#include "llvm/Support/raw_ostream.h"
27using namespace llvm;
28
29#define DEBUG_TYPE "mlx-expansion"
30
31static cl::opt<bool>
32ForceExapnd("expand-all-fp-mlx", cl::init(false), cl::Hidden);
33static cl::opt<unsigned>
34ExpandLimit("expand-limit", cl::init(~0U), cl::Hidden);
35
36STATISTIC(NumExpand, "Number of fp MLA / MLS instructions expanded");
37
38namespace {
39  struct MLxExpansion : public MachineFunctionPass {
40    static char ID;
41    MLxExpansion() : MachineFunctionPass(ID) {}
42
43    bool runOnMachineFunction(MachineFunction &Fn) override;
44
45    StringRef getPassName() const override {
46      return "ARM MLA / MLS expansion pass";
47    }
48
49  private:
50    const ARMBaseInstrInfo *TII;
51    const TargetRegisterInfo *TRI;
52    MachineRegisterInfo *MRI;
53
54    bool isLikeA9;
55    bool isSwift;
56    unsigned MIIdx;
57    MachineInstr* LastMIs[4];
58    SmallPtrSet<MachineInstr*, 4> IgnoreStall;
59
60    void clearStack();
61    void pushStack(MachineInstr *MI);
62    MachineInstr *getAccDefMI(MachineInstr *MI) const;
63    unsigned getDefReg(MachineInstr *MI) const;
64    bool hasLoopHazard(MachineInstr *MI) const;
65    bool hasRAWHazard(unsigned Reg, MachineInstr *MI) const;
66    bool FindMLxHazard(MachineInstr *MI);
67    void ExpandFPMLxInstruction(MachineBasicBlock &MBB, MachineInstr *MI,
68                                unsigned MulOpc, unsigned AddSubOpc,
69                                bool NegAcc, bool HasLane);
70    bool ExpandFPMLxInstructions(MachineBasicBlock &MBB);
71  };
72  char MLxExpansion::ID = 0;
73}
74
75void MLxExpansion::clearStack() {
76  std::fill(LastMIs, LastMIs + 4, nullptr);
77  MIIdx = 0;
78}
79
80void MLxExpansion::pushStack(MachineInstr *MI) {
81  LastMIs[MIIdx] = MI;
82  if (++MIIdx == 4)
83    MIIdx = 0;
84}
85
86MachineInstr *MLxExpansion::getAccDefMI(MachineInstr *MI) const {
87  // Look past COPY and INSERT_SUBREG instructions to find the
88  // real definition MI. This is important for _sfp instructions.
89  Register Reg = MI->getOperand(1).getReg();
90  if (Register::isPhysicalRegister(Reg))
91    return nullptr;
92
93  MachineBasicBlock *MBB = MI->getParent();
94  MachineInstr *DefMI = MRI->getVRegDef(Reg);
95  while (true) {
96    if (DefMI->getParent() != MBB)
97      break;
98    if (DefMI->isCopyLike()) {
99      Reg = DefMI->getOperand(1).getReg();
100      if (Register::isVirtualRegister(Reg)) {
101        DefMI = MRI->getVRegDef(Reg);
102        continue;
103      }
104    } else if (DefMI->isInsertSubreg()) {
105      Reg = DefMI->getOperand(2).getReg();
106      if (Register::isVirtualRegister(Reg)) {
107        DefMI = MRI->getVRegDef(Reg);
108        continue;
109      }
110    }
111    break;
112  }
113  return DefMI;
114}
115
116unsigned MLxExpansion::getDefReg(MachineInstr *MI) const {
117  Register Reg = MI->getOperand(0).getReg();
118  if (Register::isPhysicalRegister(Reg) || !MRI->hasOneNonDBGUse(Reg))
119    return Reg;
120
121  MachineBasicBlock *MBB = MI->getParent();
122  MachineInstr *UseMI = &*MRI->use_instr_nodbg_begin(Reg);
123  if (UseMI->getParent() != MBB)
124    return Reg;
125
126  while (UseMI->isCopy() || UseMI->isInsertSubreg()) {
127    Reg = UseMI->getOperand(0).getReg();
128    if (Register::isPhysicalRegister(Reg) || !MRI->hasOneNonDBGUse(Reg))
129      return Reg;
130    UseMI = &*MRI->use_instr_nodbg_begin(Reg);
131    if (UseMI->getParent() != MBB)
132      return Reg;
133  }
134
135  return Reg;
136}
137
138/// hasLoopHazard - Check whether an MLx instruction is chained to itself across
139/// a single-MBB loop.
140bool MLxExpansion::hasLoopHazard(MachineInstr *MI) const {
141  Register Reg = MI->getOperand(1).getReg();
142  if (Register::isPhysicalRegister(Reg))
143    return false;
144
145  MachineBasicBlock *MBB = MI->getParent();
146  MachineInstr *DefMI = MRI->getVRegDef(Reg);
147  while (true) {
148outer_continue:
149    if (DefMI->getParent() != MBB)
150      break;
151
152    if (DefMI->isPHI()) {
153      for (unsigned i = 1, e = DefMI->getNumOperands(); i < e; i += 2) {
154        if (DefMI->getOperand(i + 1).getMBB() == MBB) {
155          Register SrcReg = DefMI->getOperand(i).getReg();
156          if (Register::isVirtualRegister(SrcReg)) {
157            DefMI = MRI->getVRegDef(SrcReg);
158            goto outer_continue;
159          }
160        }
161      }
162    } else if (DefMI->isCopyLike()) {
163      Reg = DefMI->getOperand(1).getReg();
164      if (Register::isVirtualRegister(Reg)) {
165        DefMI = MRI->getVRegDef(Reg);
166        continue;
167      }
168    } else if (DefMI->isInsertSubreg()) {
169      Reg = DefMI->getOperand(2).getReg();
170      if (Register::isVirtualRegister(Reg)) {
171        DefMI = MRI->getVRegDef(Reg);
172        continue;
173      }
174    }
175
176    break;
177  }
178
179  return DefMI == MI;
180}
181
182bool MLxExpansion::hasRAWHazard(unsigned Reg, MachineInstr *MI) const {
183  // FIXME: Detect integer instructions properly.
184  const MCInstrDesc &MCID = MI->getDesc();
185  unsigned Domain = MCID.TSFlags & ARMII::DomainMask;
186  if (MI->mayStore())
187    return false;
188  unsigned Opcode = MCID.getOpcode();
189  if (Opcode == ARM::VMOVRS || Opcode == ARM::VMOVRRD)
190    return false;
191  if ((Domain & ARMII::DomainVFP) || (Domain & ARMII::DomainNEON))
192    return MI->readsRegister(Reg, TRI);
193  return false;
194}
195
196static bool isFpMulInstruction(unsigned Opcode) {
197  switch (Opcode) {
198  case ARM::VMULS:
199  case ARM::VMULfd:
200  case ARM::VMULfq:
201  case ARM::VMULD:
202  case ARM::VMULslfd:
203  case ARM::VMULslfq:
204    return true;
205  default:
206    return false;
207  }
208}
209
210bool MLxExpansion::FindMLxHazard(MachineInstr *MI) {
211  if (NumExpand >= ExpandLimit)
212    return false;
213
214  if (ForceExapnd)
215    return true;
216
217  MachineInstr *DefMI = getAccDefMI(MI);
218  if (TII->isFpMLxInstruction(DefMI->getOpcode())) {
219    // r0 = vmla
220    // r3 = vmla r0, r1, r2
221    // takes 16 - 17 cycles
222    //
223    // r0 = vmla
224    // r4 = vmul r1, r2
225    // r3 = vadd r0, r4
226    // takes about 14 - 15 cycles even with vmul stalling for 4 cycles.
227    IgnoreStall.insert(DefMI);
228    return true;
229  }
230
231  // On Swift, we mostly care about hazards from multiplication instructions
232  // writing the accumulator and the pipelining of loop iterations by out-of-
233  // order execution.
234  if (isSwift)
235    return isFpMulInstruction(DefMI->getOpcode()) || hasLoopHazard(MI);
236
237  if (IgnoreStall.count(MI))
238    return false;
239
240  // If a VMLA.F is followed by an VADD.F or VMUL.F with no RAW hazard, the
241  // VADD.F or VMUL.F will stall 4 cycles before issue. The 4 cycle stall
242  // preserves the in-order retirement of the instructions.
243  // Look at the next few instructions, if *most* of them can cause hazards,
244  // then the scheduler can't *fix* this, we'd better break up the VMLA.
245  unsigned Limit1 = isLikeA9 ? 1 : 4;
246  unsigned Limit2 = isLikeA9 ? 1 : 4;
247  for (unsigned i = 1; i <= 4; ++i) {
248    int Idx = ((int)MIIdx - i + 4) % 4;
249    MachineInstr *NextMI = LastMIs[Idx];
250    if (!NextMI)
251      continue;
252
253    if (TII->canCauseFpMLxStall(NextMI->getOpcode())) {
254      if (i <= Limit1)
255        return true;
256    }
257
258    // Look for VMLx RAW hazard.
259    if (i <= Limit2 && hasRAWHazard(getDefReg(MI), NextMI))
260      return true;
261  }
262
263  return false;
264}
265
266/// ExpandFPMLxInstructions - Expand a MLA / MLS instruction into a pair
267/// of MUL + ADD / SUB instructions.
268void
269MLxExpansion::ExpandFPMLxInstruction(MachineBasicBlock &MBB, MachineInstr *MI,
270                                     unsigned MulOpc, unsigned AddSubOpc,
271                                     bool NegAcc, bool HasLane) {
272  Register DstReg = MI->getOperand(0).getReg();
273  bool DstDead = MI->getOperand(0).isDead();
274  Register AccReg = MI->getOperand(1).getReg();
275  Register Src1Reg = MI->getOperand(2).getReg();
276  Register Src2Reg = MI->getOperand(3).getReg();
277  bool Src1Kill = MI->getOperand(2).isKill();
278  bool Src2Kill = MI->getOperand(3).isKill();
279  unsigned LaneImm = HasLane ? MI->getOperand(4).getImm() : 0;
280  unsigned NextOp = HasLane ? 5 : 4;
281  ARMCC::CondCodes Pred = (ARMCC::CondCodes)MI->getOperand(NextOp).getImm();
282  Register PredReg = MI->getOperand(++NextOp).getReg();
283
284  const MCInstrDesc &MCID1 = TII->get(MulOpc);
285  const MCInstrDesc &MCID2 = TII->get(AddSubOpc);
286  const MachineFunction &MF = *MI->getParent()->getParent();
287  Register TmpReg =
288      MRI->createVirtualRegister(TII->getRegClass(MCID1, 0, TRI, MF));
289
290  MachineInstrBuilder MIB = BuildMI(MBB, MI, MI->getDebugLoc(), MCID1, TmpReg)
291    .addReg(Src1Reg, getKillRegState(Src1Kill))
292    .addReg(Src2Reg, getKillRegState(Src2Kill));
293  if (HasLane)
294    MIB.addImm(LaneImm);
295  MIB.addImm(Pred).addReg(PredReg);
296
297  MIB = BuildMI(MBB, MI, MI->getDebugLoc(), MCID2)
298    .addReg(DstReg, getDefRegState(true) | getDeadRegState(DstDead));
299
300  if (NegAcc) {
301    bool AccKill = MRI->hasOneNonDBGUse(AccReg);
302    MIB.addReg(TmpReg, getKillRegState(true))
303       .addReg(AccReg, getKillRegState(AccKill));
304  } else {
305    MIB.addReg(AccReg).addReg(TmpReg, getKillRegState(true));
306  }
307  MIB.addImm(Pred).addReg(PredReg);
308
309  LLVM_DEBUG({
310    dbgs() << "Expanding: " << *MI;
311    dbgs() << "  to:\n";
312    MachineBasicBlock::iterator MII = MI;
313    MII = std::prev(MII);
314    MachineInstr &MI2 = *MII;
315    MII = std::prev(MII);
316    MachineInstr &MI1 = *MII;
317    dbgs() << "    " << MI1;
318    dbgs() << "    " << MI2;
319  });
320
321  MI->eraseFromParent();
322  ++NumExpand;
323}
324
325bool MLxExpansion::ExpandFPMLxInstructions(MachineBasicBlock &MBB) {
326  bool Changed = false;
327
328  clearStack();
329  IgnoreStall.clear();
330
331  unsigned Skip = 0;
332  MachineBasicBlock::reverse_iterator MII = MBB.rbegin(), E = MBB.rend();
333  while (MII != E) {
334    MachineInstr *MI = &*MII++;
335
336    if (MI->isPosition() || MI->isImplicitDef() || MI->isCopy())
337      continue;
338
339    const MCInstrDesc &MCID = MI->getDesc();
340    if (MI->isBarrier()) {
341      clearStack();
342      Skip = 0;
343      continue;
344    }
345
346    unsigned Domain = MCID.TSFlags & ARMII::DomainMask;
347    if (Domain == ARMII::DomainGeneral) {
348      if (++Skip == 2)
349        // Assume dual issues of non-VFP / NEON instructions.
350        pushStack(nullptr);
351    } else {
352      Skip = 0;
353
354      unsigned MulOpc, AddSubOpc;
355      bool NegAcc, HasLane;
356      if (!TII->isFpMLxInstruction(MCID.getOpcode(),
357                                   MulOpc, AddSubOpc, NegAcc, HasLane) ||
358          !FindMLxHazard(MI))
359        pushStack(MI);
360      else {
361        ExpandFPMLxInstruction(MBB, MI, MulOpc, AddSubOpc, NegAcc, HasLane);
362        Changed = true;
363      }
364    }
365  }
366
367  return Changed;
368}
369
370bool MLxExpansion::runOnMachineFunction(MachineFunction &Fn) {
371  if (skipFunction(Fn.getFunction()))
372    return false;
373
374  TII = static_cast<const ARMBaseInstrInfo *>(Fn.getSubtarget().getInstrInfo());
375  TRI = Fn.getSubtarget().getRegisterInfo();
376  MRI = &Fn.getRegInfo();
377  const ARMSubtarget *STI = &Fn.getSubtarget<ARMSubtarget>();
378  if (!STI->expandMLx())
379    return false;
380  isLikeA9 = STI->isLikeA9() || STI->isSwift();
381  isSwift = STI->isSwift();
382
383  bool Modified = false;
384  for (MachineBasicBlock &MBB : Fn)
385    Modified |= ExpandFPMLxInstructions(MBB);
386
387  return Modified;
388}
389
390FunctionPass *llvm::createMLxExpansionPass() {
391  return new MLxExpansion();
392}
393