1//===-------------- BPFMIPeephole.cpp - MI Peephole Cleanups  -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass performs peephole optimizations to cleanup ugly code sequences at
10// MachineInstruction layer.
11//
// Currently, the following optimizations are implemented:
//  - One pre-RA MachineSSA pass to eliminate type promotion sequences, i.e.
//    those that zero extend 32-bit subregisters to 64-bit registers, if the
//    compiler can prove the subregisters are defined by 32-bit operations, in
//    which case the upper half of the underlying 64-bit registers was zeroed
//    implicitly.
//
//  - One post-RA PreEmit pass to do final cleanup on some redundant
//    instructions generated due to bad RA on subregisters. When gotol is
//    supported, this pass also rewrites branches whose estimated target
//    offset does not fit in 16 bits.
21//===----------------------------------------------------------------------===//
22
23#include "BPF.h"
24#include "BPFInstrInfo.h"
25#include "BPFTargetMachine.h"
26#include "llvm/ADT/Statistic.h"
27#include "llvm/CodeGen/MachineFunctionPass.h"
28#include "llvm/CodeGen/MachineInstrBuilder.h"
29#include "llvm/CodeGen/MachineRegisterInfo.h"
30#include "llvm/Support/Debug.h"
31#include <set>
32
33using namespace llvm;
34
35#define DEBUG_TYPE "bpf-mi-zext-elim"
36
// Absolute bound used by BPFMIPreEmitPeephole::in16BitRange(): an estimated
// branch distance (counted in machine instructions) within
// [-GotolAbsLowBound, GotolAbsLowBound] is treated as reachable by a 16-bit
// jump. The default is INT16_MAX >> 1, i.e. half of the positive 16-bit range.
static cl::opt<int> GotolAbsLowBound("gotol-abs-low-bound", cl::Hidden,
  cl::init(INT16_MAX >> 1), cl::desc("Specify gotol lower bound"));

// Incremented by eliminateZExtSeq() for every (MOV_32_64, SLL 32, SRL 32)
// sequence it rewrites.
STATISTIC(ZExtElemNum, "Number of zero extension shifts eliminated");
41
42namespace {
43
44struct BPFMIPeephole : public MachineFunctionPass {
45
46  static char ID;
47  const BPFInstrInfo *TII;
48  MachineFunction *MF;
49  MachineRegisterInfo *MRI;
50
51  BPFMIPeephole() : MachineFunctionPass(ID) {
52    initializeBPFMIPeepholePass(*PassRegistry::getPassRegistry());
53  }
54
55private:
56  // Initialize class variables.
57  void initialize(MachineFunction &MFParm);
58
59  bool isCopyFrom32Def(MachineInstr *CopyMI);
60  bool isInsnFrom32Def(MachineInstr *DefInsn);
61  bool isPhiFrom32Def(MachineInstr *MovMI);
62  bool isMovFrom32Def(MachineInstr *MovMI);
63  bool eliminateZExtSeq();
64  bool eliminateZExt();
65
66  std::set<MachineInstr *> PhiInsns;
67
68public:
69
70  // Main entry point for this pass.
71  bool runOnMachineFunction(MachineFunction &MF) override {
72    if (skipFunction(MF.getFunction()))
73      return false;
74
75    initialize(MF);
76
77    // First try to eliminate (zext, lshift, rshift) and then
78    // try to eliminate zext.
79    bool ZExtSeqExist, ZExtExist;
80    ZExtSeqExist = eliminateZExtSeq();
81    ZExtExist = eliminateZExt();
82    return ZExtSeqExist || ZExtExist;
83  }
84};
85
86// Initialize class variables.
87void BPFMIPeephole::initialize(MachineFunction &MFParm) {
88  MF = &MFParm;
89  MRI = &MF->getRegInfo();
90  TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
91  LLVM_DEBUG(dbgs() << "*** BPF MachineSSA ZEXT Elim peephole pass ***\n\n");
92}
93
94bool BPFMIPeephole::isCopyFrom32Def(MachineInstr *CopyMI)
95{
96  MachineOperand &opnd = CopyMI->getOperand(1);
97
98  if (!opnd.isReg())
99    return false;
100
101  // Return false if getting value from a 32bit physical register.
102  // Most likely, this physical register is aliased to
103  // function call return value or current function parameters.
104  Register Reg = opnd.getReg();
105  if (!Reg.isVirtual())
106    return false;
107
108  if (MRI->getRegClass(Reg) == &BPF::GPRRegClass)
109    return false;
110
111  MachineInstr *DefInsn = MRI->getVRegDef(Reg);
112  if (!isInsnFrom32Def(DefInsn))
113    return false;
114
115  return true;
116}
117
118bool BPFMIPeephole::isPhiFrom32Def(MachineInstr *PhiMI)
119{
120  for (unsigned i = 1, e = PhiMI->getNumOperands(); i < e; i += 2) {
121    MachineOperand &opnd = PhiMI->getOperand(i);
122
123    if (!opnd.isReg())
124      return false;
125
126    MachineInstr *PhiDef = MRI->getVRegDef(opnd.getReg());
127    if (!PhiDef)
128      return false;
129    if (PhiDef->isPHI()) {
130      if (!PhiInsns.insert(PhiDef).second)
131        return false;
132      if (!isPhiFrom32Def(PhiDef))
133        return false;
134    }
135    if (PhiDef->getOpcode() == BPF::COPY && !isCopyFrom32Def(PhiDef))
136      return false;
137  }
138
139  return true;
140}
141
142// The \p DefInsn instruction defines a virtual register.
143bool BPFMIPeephole::isInsnFrom32Def(MachineInstr *DefInsn)
144{
145  if (!DefInsn)
146    return false;
147
148  if (DefInsn->isPHI()) {
149    if (!PhiInsns.insert(DefInsn).second)
150      return false;
151    if (!isPhiFrom32Def(DefInsn))
152      return false;
153  } else if (DefInsn->getOpcode() == BPF::COPY) {
154    if (!isCopyFrom32Def(DefInsn))
155      return false;
156  }
157
158  return true;
159}
160
161bool BPFMIPeephole::isMovFrom32Def(MachineInstr *MovMI)
162{
163  MachineInstr *DefInsn = MRI->getVRegDef(MovMI->getOperand(1).getReg());
164
165  LLVM_DEBUG(dbgs() << "  Def of Mov Src:");
166  LLVM_DEBUG(DefInsn->dump());
167
168  PhiInsns.clear();
169  if (!isInsnFrom32Def(DefInsn))
170    return false;
171
172  LLVM_DEBUG(dbgs() << "  One ZExt elim sequence identified.\n");
173
174  return true;
175}
176
177bool BPFMIPeephole::eliminateZExtSeq() {
178  MachineInstr* ToErase = nullptr;
179  bool Eliminated = false;
180
181  for (MachineBasicBlock &MBB : *MF) {
182    for (MachineInstr &MI : MBB) {
183      // If the previous instruction was marked for elimination, remove it now.
184      if (ToErase) {
185        ToErase->eraseFromParent();
186        ToErase = nullptr;
187      }
188
189      // Eliminate the 32-bit to 64-bit zero extension sequence when possible.
190      //
191      //   MOV_32_64 rB, wA
192      //   SLL_ri    rB, rB, 32
193      //   SRL_ri    rB, rB, 32
194      if (MI.getOpcode() == BPF::SRL_ri &&
195          MI.getOperand(2).getImm() == 32) {
196        Register DstReg = MI.getOperand(0).getReg();
197        Register ShfReg = MI.getOperand(1).getReg();
198        MachineInstr *SllMI = MRI->getVRegDef(ShfReg);
199
200        LLVM_DEBUG(dbgs() << "Starting SRL found:");
201        LLVM_DEBUG(MI.dump());
202
203        if (!SllMI ||
204            SllMI->isPHI() ||
205            SllMI->getOpcode() != BPF::SLL_ri ||
206            SllMI->getOperand(2).getImm() != 32)
207          continue;
208
209        LLVM_DEBUG(dbgs() << "  SLL found:");
210        LLVM_DEBUG(SllMI->dump());
211
212        MachineInstr *MovMI = MRI->getVRegDef(SllMI->getOperand(1).getReg());
213        if (!MovMI ||
214            MovMI->isPHI() ||
215            MovMI->getOpcode() != BPF::MOV_32_64)
216          continue;
217
218        LLVM_DEBUG(dbgs() << "  Type cast Mov found:");
219        LLVM_DEBUG(MovMI->dump());
220
221        Register SubReg = MovMI->getOperand(1).getReg();
222        if (!isMovFrom32Def(MovMI)) {
223          LLVM_DEBUG(dbgs()
224                     << "  One ZExt elim sequence failed qualifying elim.\n");
225          continue;
226        }
227
228        BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), DstReg)
229          .addImm(0).addReg(SubReg).addImm(BPF::sub_32);
230
231        SllMI->eraseFromParent();
232        MovMI->eraseFromParent();
233        // MI is the right shift, we can't erase it in it's own iteration.
234        // Mark it to ToErase, and erase in the next iteration.
235        ToErase = &MI;
236        ZExtElemNum++;
237        Eliminated = true;
238      }
239    }
240  }
241
242  return Eliminated;
243}
244
245bool BPFMIPeephole::eliminateZExt() {
246  MachineInstr* ToErase = nullptr;
247  bool Eliminated = false;
248
249  for (MachineBasicBlock &MBB : *MF) {
250    for (MachineInstr &MI : MBB) {
251      // If the previous instruction was marked for elimination, remove it now.
252      if (ToErase) {
253        ToErase->eraseFromParent();
254        ToErase = nullptr;
255      }
256
257      if (MI.getOpcode() != BPF::MOV_32_64)
258        continue;
259
260      // Eliminate MOV_32_64 if possible.
261      //   MOV_32_64 rA, wB
262      //
263      // If wB has been zero extended, replace it with a SUBREG_TO_REG.
264      // This is to workaround BPF programs where pkt->{data, data_end}
265      // is encoded as u32, but actually the verifier populates them
266      // as 64bit pointer. The MOV_32_64 will zero out the top 32 bits.
267      LLVM_DEBUG(dbgs() << "Candidate MOV_32_64 instruction:");
268      LLVM_DEBUG(MI.dump());
269
270      if (!isMovFrom32Def(&MI))
271        continue;
272
273      LLVM_DEBUG(dbgs() << "Removing the MOV_32_64 instruction\n");
274
275      Register dst = MI.getOperand(0).getReg();
276      Register src = MI.getOperand(1).getReg();
277
278      // Build a SUBREG_TO_REG instruction.
279      BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), dst)
280        .addImm(0).addReg(src).addImm(BPF::sub_32);
281
282      ToErase = &MI;
283      Eliminated = true;
284    }
285  }
286
287  return Eliminated;
288}
289
290} // end default namespace
291
292INITIALIZE_PASS(BPFMIPeephole, DEBUG_TYPE,
293                "BPF MachineSSA Peephole Optimization For ZEXT Eliminate",
294                false, false)
295
296char BPFMIPeephole::ID = 0;
297FunctionPass* llvm::createBPFMIPeepholePass() { return new BPFMIPeephole(); }
298
// Incremented by BPFMIPreEmitPeephole::eliminateRedundantMov() for every
// identical "MOV rA, rA" it removes.
STATISTIC(RedundantMovElemNum, "Number of redundant moves eliminated");
300
301namespace {
302
303struct BPFMIPreEmitPeephole : public MachineFunctionPass {
304
305  static char ID;
306  MachineFunction *MF;
307  const TargetRegisterInfo *TRI;
308  const BPFInstrInfo *TII;
309  bool SupportGotol;
310
311  BPFMIPreEmitPeephole() : MachineFunctionPass(ID) {
312    initializeBPFMIPreEmitPeepholePass(*PassRegistry::getPassRegistry());
313  }
314
315private:
316  // Initialize class variables.
317  void initialize(MachineFunction &MFParm);
318
319  bool in16BitRange(int Num);
320  bool eliminateRedundantMov();
321  bool adjustBranch();
322
323public:
324
325  // Main entry point for this pass.
326  bool runOnMachineFunction(MachineFunction &MF) override {
327    if (skipFunction(MF.getFunction()))
328      return false;
329
330    initialize(MF);
331
332    bool Changed;
333    Changed = eliminateRedundantMov();
334    if (SupportGotol)
335      Changed = adjustBranch() || Changed;
336    return Changed;
337  }
338};
339
340// Initialize class variables.
341void BPFMIPreEmitPeephole::initialize(MachineFunction &MFParm) {
342  MF = &MFParm;
343  TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
344  TRI = MF->getSubtarget<BPFSubtarget>().getRegisterInfo();
345  SupportGotol = MF->getSubtarget<BPFSubtarget>().hasGotol();
346  LLVM_DEBUG(dbgs() << "*** BPF PreEmit peephole pass ***\n\n");
347}
348
349bool BPFMIPreEmitPeephole::eliminateRedundantMov() {
350  MachineInstr* ToErase = nullptr;
351  bool Eliminated = false;
352
353  for (MachineBasicBlock &MBB : *MF) {
354    for (MachineInstr &MI : MBB) {
355      // If the previous instruction was marked for elimination, remove it now.
356      if (ToErase) {
357        LLVM_DEBUG(dbgs() << "  Redundant Mov Eliminated:");
358        LLVM_DEBUG(ToErase->dump());
359        ToErase->eraseFromParent();
360        ToErase = nullptr;
361      }
362
363      // Eliminate identical move:
364      //
365      //   MOV rA, rA
366      //
367      // Note that we cannot remove
368      //   MOV_32_64  rA, wA
369      //   MOV_rr_32  wA, wA
370      // as these two instructions having side effects, zeroing out
371      // top 32 bits of rA.
372      unsigned Opcode = MI.getOpcode();
373      if (Opcode == BPF::MOV_rr) {
374        Register dst = MI.getOperand(0).getReg();
375        Register src = MI.getOperand(1).getReg();
376
377        if (dst != src)
378          continue;
379
380        ToErase = &MI;
381        RedundantMovElemNum++;
382        Eliminated = true;
383      }
384    }
385  }
386
387  return Eliminated;
388}
389
390bool BPFMIPreEmitPeephole::in16BitRange(int Num) {
391  // Well, the cut-off is not precisely at 16bit range since
392  // new codes are added during the transformation. So let us
393  // a little bit conservative.
394  return Num >= -GotolAbsLowBound && Num <= GotolAbsLowBound;
395}
396
397// Before cpu=v4, only 16bit branch target offset (-0x8000 to 0x7fff)
398// is supported for both unconditional (JMP) and condition (JEQ, JSGT,
399// etc.) branches. In certain cases, e.g., full unrolling, the branch
400// target offset might exceed 16bit range. If this happens, the llvm
401// will generate incorrect code as the offset is truncated to 16bit.
402//
403// To fix this rare case, a new insn JMPL is introduced. This new
404// insn supports supports 32bit branch target offset. The compiler
405// does not use this insn during insn selection. Rather, BPF backend
406// will estimate the branch target offset and do JMP -> JMPL and
407// JEQ -> JEQ + JMPL conversion if the estimated branch target offset
408// is beyond 16bit.
409bool BPFMIPreEmitPeephole::adjustBranch() {
410  bool Changed = false;
411  int CurrNumInsns = 0;
412  DenseMap<MachineBasicBlock *, int> SoFarNumInsns;
413  DenseMap<MachineBasicBlock *, MachineBasicBlock *> FollowThroughBB;
414  std::vector<MachineBasicBlock *> MBBs;
415
416  MachineBasicBlock *PrevBB = nullptr;
417  for (MachineBasicBlock &MBB : *MF) {
418    // MBB.size() is the number of insns in this basic block, including some
419    // debug info, e.g., DEBUG_VALUE, so we may over-count a little bit.
420    // Typically we have way more normal insns than DEBUG_VALUE insns.
421    // Also, if we indeed need to convert conditional branch like JEQ to
422    // JEQ + JMPL, we actually introduced some new insns like below.
423    CurrNumInsns += (int)MBB.size();
424    SoFarNumInsns[&MBB] = CurrNumInsns;
425    if (PrevBB != nullptr)
426      FollowThroughBB[PrevBB] = &MBB;
427    PrevBB = &MBB;
428    // A list of original BBs to make later traveral easier.
429    MBBs.push_back(&MBB);
430  }
431  FollowThroughBB[PrevBB] = nullptr;
432
433  for (unsigned i = 0; i < MBBs.size(); i++) {
434    // We have four cases here:
435    //  (1). no terminator, simple follow through.
436    //  (2). jmp to another bb.
437    //  (3). conditional jmp to another bb or follow through.
438    //  (4). conditional jmp followed by an unconditional jmp.
439    MachineInstr *CondJmp = nullptr, *UncondJmp = nullptr;
440
441    MachineBasicBlock *MBB = MBBs[i];
442    for (MachineInstr &Term : MBB->terminators()) {
443      if (Term.isConditionalBranch()) {
444        assert(CondJmp == nullptr);
445        CondJmp = &Term;
446      } else if (Term.isUnconditionalBranch()) {
447        assert(UncondJmp == nullptr);
448        UncondJmp = &Term;
449      }
450    }
451
452    // (1). no terminator, simple follow through.
453    if (!CondJmp && !UncondJmp)
454      continue;
455
456    MachineBasicBlock *CondTargetBB, *JmpBB;
457    CurrNumInsns = SoFarNumInsns[MBB];
458
459    // (2). jmp to another bb.
460    if (!CondJmp && UncondJmp) {
461      JmpBB = UncondJmp->getOperand(0).getMBB();
462      if (in16BitRange(SoFarNumInsns[JmpBB] - JmpBB->size() - CurrNumInsns))
463        continue;
464
465      // replace this insn as a JMPL.
466      BuildMI(MBB, UncondJmp->getDebugLoc(), TII->get(BPF::JMPL)).addMBB(JmpBB);
467      UncondJmp->eraseFromParent();
468      Changed = true;
469      continue;
470    }
471
472    const BasicBlock *TermBB = MBB->getBasicBlock();
473    int Dist;
474
475    // (3). conditional jmp to another bb or follow through.
476    if (!UncondJmp) {
477      CondTargetBB = CondJmp->getOperand(2).getMBB();
478      MachineBasicBlock *FollowBB = FollowThroughBB[MBB];
479      Dist = SoFarNumInsns[CondTargetBB] - CondTargetBB->size() - CurrNumInsns;
480      if (in16BitRange(Dist))
481        continue;
482
483      // We have
484      //   B2: ...
485      //       if (cond) goto B5
486      //   B3: ...
487      // where B2 -> B5 is beyond 16bit range.
488      //
489      // We do not have 32bit cond jmp insn. So we try to do
490      // the following.
491      //   B2:     ...
492      //           if (cond) goto New_B1
493      //   New_B0  goto B3
494      //   New_B1: gotol B5
495      //   B3: ...
496      // Basically two new basic blocks are created.
497      MachineBasicBlock *New_B0 = MF->CreateMachineBasicBlock(TermBB);
498      MachineBasicBlock *New_B1 = MF->CreateMachineBasicBlock(TermBB);
499
500      // Insert New_B0 and New_B1 into function block list.
501      MachineFunction::iterator MBB_I  = ++MBB->getIterator();
502      MF->insert(MBB_I, New_B0);
503      MF->insert(MBB_I, New_B1);
504
505      // replace B2 cond jump
506      if (CondJmp->getOperand(1).isReg())
507        BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
508            .addReg(CondJmp->getOperand(0).getReg())
509            .addReg(CondJmp->getOperand(1).getReg())
510            .addMBB(New_B1);
511      else
512        BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
513            .addReg(CondJmp->getOperand(0).getReg())
514            .addImm(CondJmp->getOperand(1).getImm())
515            .addMBB(New_B1);
516
517      // it is possible that CondTargetBB and FollowBB are the same. But the
518      // above Dist checking should already filtered this case.
519      MBB->removeSuccessor(CondTargetBB);
520      MBB->removeSuccessor(FollowBB);
521      MBB->addSuccessor(New_B0);
522      MBB->addSuccessor(New_B1);
523
524      // Populate insns in New_B0 and New_B1.
525      BuildMI(New_B0, CondJmp->getDebugLoc(), TII->get(BPF::JMP)).addMBB(FollowBB);
526      BuildMI(New_B1, CondJmp->getDebugLoc(), TII->get(BPF::JMPL))
527          .addMBB(CondTargetBB);
528
529      New_B0->addSuccessor(FollowBB);
530      New_B1->addSuccessor(CondTargetBB);
531      CondJmp->eraseFromParent();
532      Changed = true;
533      continue;
534    }
535
536    //  (4). conditional jmp followed by an unconditional jmp.
537    CondTargetBB = CondJmp->getOperand(2).getMBB();
538    JmpBB = UncondJmp->getOperand(0).getMBB();
539
540    // We have
541    //   B2: ...
542    //       if (cond) goto B5
543    //       JMP B7
544    //   B3: ...
545    //
546    // If only B2->B5 is out of 16bit range, we can do
547    //   B2: ...
548    //       if (cond) goto new_B
549    //       JMP B7
550    //   New_B: gotol B5
551    //   B3: ...
552    //
553    // If only 'JMP B7' is out of 16bit range, we can replace
554    // 'JMP B7' with 'JMPL B7'.
555    //
556    // If both B2->B5 and 'JMP B7' is out of range, just do
557    // both the above transformations.
558    Dist = SoFarNumInsns[CondTargetBB] - CondTargetBB->size() - CurrNumInsns;
559    if (!in16BitRange(Dist)) {
560      MachineBasicBlock *New_B = MF->CreateMachineBasicBlock(TermBB);
561
562      // Insert New_B0 into function block list.
563      MF->insert(++MBB->getIterator(), New_B);
564
565      // replace B2 cond jump
566      if (CondJmp->getOperand(1).isReg())
567        BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
568            .addReg(CondJmp->getOperand(0).getReg())
569            .addReg(CondJmp->getOperand(1).getReg())
570            .addMBB(New_B);
571      else
572        BuildMI(*MBB, MachineBasicBlock::iterator(*CondJmp), CondJmp->getDebugLoc(), TII->get(CondJmp->getOpcode()))
573            .addReg(CondJmp->getOperand(0).getReg())
574            .addImm(CondJmp->getOperand(1).getImm())
575            .addMBB(New_B);
576
577      if (CondTargetBB != JmpBB)
578        MBB->removeSuccessor(CondTargetBB);
579      MBB->addSuccessor(New_B);
580
581      // Populate insn in New_B.
582      BuildMI(New_B, CondJmp->getDebugLoc(), TII->get(BPF::JMPL)).addMBB(CondTargetBB);
583
584      New_B->addSuccessor(CondTargetBB);
585      CondJmp->eraseFromParent();
586      Changed = true;
587    }
588
589    if (!in16BitRange(SoFarNumInsns[JmpBB] - CurrNumInsns)) {
590      BuildMI(MBB, UncondJmp->getDebugLoc(), TII->get(BPF::JMPL)).addMBB(JmpBB);
591      UncondJmp->eraseFromParent();
592      Changed = true;
593    }
594  }
595
596  return Changed;
597}
598
599} // end default namespace
600
601INITIALIZE_PASS(BPFMIPreEmitPeephole, "bpf-mi-pemit-peephole",
602                "BPF PreEmit Peephole Optimization", false, false)
603
604char BPFMIPreEmitPeephole::ID = 0;
605FunctionPass* llvm::createBPFMIPreEmitPeepholePass()
606{
607  return new BPFMIPreEmitPeephole();
608}
609