1//===-------------- BPFMIPeephole.cpp - MI Peephole Cleanups  -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass performs peephole optimizations to cleanup ugly code sequences at
10// MachineInstruction layer.
11//
// Currently, the following optimizations are implemented:
//  - One pre-RA MachineSSA pass to eliminate type promotion sequences, i.e.
//    those that zero extend 32-bit subregisters to 64-bit registers, if the
//    compiler can prove that the subregisters are defined by 32-bit
//    operations, in which case the upper half of the underlying 64-bit
//    registers was already zeroed implicitly.
//
//  - One post-RA PreEmit pass to do final cleanup of some redundant
//    instructions generated due to bad RA on subregisters.
//
//  - One MachineSSA pass to eliminate truncation sequences (an AND with
//    0xff/0xffff, or a shift-left/shift-right pair by 32) applied to values
//    produced by loads of a matching width, since BPF loads already zero
//    extend.
21//===----------------------------------------------------------------------===//
22
23#include "BPF.h"
24#include "BPFInstrInfo.h"
25#include "BPFTargetMachine.h"
26#include "llvm/ADT/Statistic.h"
27#include "llvm/CodeGen/MachineInstrBuilder.h"
28#include "llvm/CodeGen/MachineRegisterInfo.h"
29#include "llvm/Support/Debug.h"
30#include <set>
31
32using namespace llvm;
33
34#define DEBUG_TYPE "bpf-mi-zext-elim"
35
36STATISTIC(ZExtElemNum, "Number of zero extension shifts eliminated");
37
38namespace {
39
40struct BPFMIPeephole : public MachineFunctionPass {
41
42  static char ID;
43  const BPFInstrInfo *TII;
44  MachineFunction *MF;
45  MachineRegisterInfo *MRI;
46
47  BPFMIPeephole() : MachineFunctionPass(ID) {
48    initializeBPFMIPeepholePass(*PassRegistry::getPassRegistry());
49  }
50
51private:
52  // Initialize class variables.
53  void initialize(MachineFunction &MFParm);
54
55  bool isCopyFrom32Def(MachineInstr *CopyMI);
56  bool isInsnFrom32Def(MachineInstr *DefInsn);
57  bool isPhiFrom32Def(MachineInstr *MovMI);
58  bool isMovFrom32Def(MachineInstr *MovMI);
59  bool eliminateZExtSeq(void);
60  bool eliminateZExt(void);
61
62  std::set<MachineInstr *> PhiInsns;
63
64public:
65
66  // Main entry point for this pass.
67  bool runOnMachineFunction(MachineFunction &MF) override {
68    if (skipFunction(MF.getFunction()))
69      return false;
70
71    initialize(MF);
72
73    // First try to eliminate (zext, lshift, rshift) and then
74    // try to eliminate zext.
75    bool ZExtSeqExist, ZExtExist;
76    ZExtSeqExist = eliminateZExtSeq();
77    ZExtExist = eliminateZExt();
78    return ZExtSeqExist || ZExtExist;
79  }
80};
81
82// Initialize class variables.
83void BPFMIPeephole::initialize(MachineFunction &MFParm) {
84  MF = &MFParm;
85  MRI = &MF->getRegInfo();
86  TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
87  LLVM_DEBUG(dbgs() << "*** BPF MachineSSA ZEXT Elim peephole pass ***\n\n");
88}
89
90bool BPFMIPeephole::isCopyFrom32Def(MachineInstr *CopyMI)
91{
92  MachineOperand &opnd = CopyMI->getOperand(1);
93
94  if (!opnd.isReg())
95    return false;
96
97  // Return false if getting value from a 32bit physical register.
98  // Most likely, this physical register is aliased to
99  // function call return value or current function parameters.
100  Register Reg = opnd.getReg();
101  if (!Register::isVirtualRegister(Reg))
102    return false;
103
104  if (MRI->getRegClass(Reg) == &BPF::GPRRegClass)
105    return false;
106
107  MachineInstr *DefInsn = MRI->getVRegDef(Reg);
108  if (!isInsnFrom32Def(DefInsn))
109    return false;
110
111  return true;
112}
113
114bool BPFMIPeephole::isPhiFrom32Def(MachineInstr *PhiMI)
115{
116  for (unsigned i = 1, e = PhiMI->getNumOperands(); i < e; i += 2) {
117    MachineOperand &opnd = PhiMI->getOperand(i);
118
119    if (!opnd.isReg())
120      return false;
121
122    MachineInstr *PhiDef = MRI->getVRegDef(opnd.getReg());
123    if (!PhiDef)
124      return false;
125    if (PhiDef->isPHI()) {
126      if (PhiInsns.find(PhiDef) != PhiInsns.end())
127        return false;
128      PhiInsns.insert(PhiDef);
129      if (!isPhiFrom32Def(PhiDef))
130        return false;
131    }
132    if (PhiDef->getOpcode() == BPF::COPY && !isCopyFrom32Def(PhiDef))
133      return false;
134  }
135
136  return true;
137}
138
139// The \p DefInsn instruction defines a virtual register.
140bool BPFMIPeephole::isInsnFrom32Def(MachineInstr *DefInsn)
141{
142  if (!DefInsn)
143    return false;
144
145  if (DefInsn->isPHI()) {
146    if (PhiInsns.find(DefInsn) != PhiInsns.end())
147      return false;
148    PhiInsns.insert(DefInsn);
149    if (!isPhiFrom32Def(DefInsn))
150      return false;
151  } else if (DefInsn->getOpcode() == BPF::COPY) {
152    if (!isCopyFrom32Def(DefInsn))
153      return false;
154  }
155
156  return true;
157}
158
159bool BPFMIPeephole::isMovFrom32Def(MachineInstr *MovMI)
160{
161  MachineInstr *DefInsn = MRI->getVRegDef(MovMI->getOperand(1).getReg());
162
163  LLVM_DEBUG(dbgs() << "  Def of Mov Src:");
164  LLVM_DEBUG(DefInsn->dump());
165
166  PhiInsns.clear();
167  if (!isInsnFrom32Def(DefInsn))
168    return false;
169
170  LLVM_DEBUG(dbgs() << "  One ZExt elim sequence identified.\n");
171
172  return true;
173}
174
175bool BPFMIPeephole::eliminateZExtSeq(void) {
176  MachineInstr* ToErase = nullptr;
177  bool Eliminated = false;
178
179  for (MachineBasicBlock &MBB : *MF) {
180    for (MachineInstr &MI : MBB) {
181      // If the previous instruction was marked for elimination, remove it now.
182      if (ToErase) {
183        ToErase->eraseFromParent();
184        ToErase = nullptr;
185      }
186
187      // Eliminate the 32-bit to 64-bit zero extension sequence when possible.
188      //
189      //   MOV_32_64 rB, wA
190      //   SLL_ri    rB, rB, 32
191      //   SRL_ri    rB, rB, 32
192      if (MI.getOpcode() == BPF::SRL_ri &&
193          MI.getOperand(2).getImm() == 32) {
194        Register DstReg = MI.getOperand(0).getReg();
195        Register ShfReg = MI.getOperand(1).getReg();
196        MachineInstr *SllMI = MRI->getVRegDef(ShfReg);
197
198        LLVM_DEBUG(dbgs() << "Starting SRL found:");
199        LLVM_DEBUG(MI.dump());
200
201        if (!SllMI ||
202            SllMI->isPHI() ||
203            SllMI->getOpcode() != BPF::SLL_ri ||
204            SllMI->getOperand(2).getImm() != 32)
205          continue;
206
207        LLVM_DEBUG(dbgs() << "  SLL found:");
208        LLVM_DEBUG(SllMI->dump());
209
210        MachineInstr *MovMI = MRI->getVRegDef(SllMI->getOperand(1).getReg());
211        if (!MovMI ||
212            MovMI->isPHI() ||
213            MovMI->getOpcode() != BPF::MOV_32_64)
214          continue;
215
216        LLVM_DEBUG(dbgs() << "  Type cast Mov found:");
217        LLVM_DEBUG(MovMI->dump());
218
219        Register SubReg = MovMI->getOperand(1).getReg();
220        if (!isMovFrom32Def(MovMI)) {
221          LLVM_DEBUG(dbgs()
222                     << "  One ZExt elim sequence failed qualifying elim.\n");
223          continue;
224        }
225
226        BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), DstReg)
227          .addImm(0).addReg(SubReg).addImm(BPF::sub_32);
228
229        SllMI->eraseFromParent();
230        MovMI->eraseFromParent();
231        // MI is the right shift, we can't erase it in it's own iteration.
232        // Mark it to ToErase, and erase in the next iteration.
233        ToErase = &MI;
234        ZExtElemNum++;
235        Eliminated = true;
236      }
237    }
238  }
239
240  return Eliminated;
241}
242
243bool BPFMIPeephole::eliminateZExt(void) {
244  MachineInstr* ToErase = nullptr;
245  bool Eliminated = false;
246
247  for (MachineBasicBlock &MBB : *MF) {
248    for (MachineInstr &MI : MBB) {
249      // If the previous instruction was marked for elimination, remove it now.
250      if (ToErase) {
251        ToErase->eraseFromParent();
252        ToErase = nullptr;
253      }
254
255      if (MI.getOpcode() != BPF::MOV_32_64)
256        continue;
257
258      // Eliminate MOV_32_64 if possible.
259      //   MOV_32_64 rA, wB
260      //
261      // If wB has been zero extended, replace it with a SUBREG_TO_REG.
262      // This is to workaround BPF programs where pkt->{data, data_end}
263      // is encoded as u32, but actually the verifier populates them
264      // as 64bit pointer. The MOV_32_64 will zero out the top 32 bits.
265      LLVM_DEBUG(dbgs() << "Candidate MOV_32_64 instruction:");
266      LLVM_DEBUG(MI.dump());
267
268      if (!isMovFrom32Def(&MI))
269        continue;
270
271      LLVM_DEBUG(dbgs() << "Removing the MOV_32_64 instruction\n");
272
273      Register dst = MI.getOperand(0).getReg();
274      Register src = MI.getOperand(1).getReg();
275
276      // Build a SUBREG_TO_REG instruction.
277      BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::SUBREG_TO_REG), dst)
278        .addImm(0).addReg(src).addImm(BPF::sub_32);
279
280      ToErase = &MI;
281      Eliminated = true;
282    }
283  }
284
285  return Eliminated;
286}
287
288} // end default namespace
289
290INITIALIZE_PASS(BPFMIPeephole, DEBUG_TYPE,
291                "BPF MachineSSA Peephole Optimization For ZEXT Eliminate",
292                false, false)
293
294char BPFMIPeephole::ID = 0;
295FunctionPass* llvm::createBPFMIPeepholePass() { return new BPFMIPeephole(); }
296
297STATISTIC(RedundantMovElemNum, "Number of redundant moves eliminated");
298
299namespace {
300
301struct BPFMIPreEmitPeephole : public MachineFunctionPass {
302
303  static char ID;
304  MachineFunction *MF;
305  const TargetRegisterInfo *TRI;
306
307  BPFMIPreEmitPeephole() : MachineFunctionPass(ID) {
308    initializeBPFMIPreEmitPeepholePass(*PassRegistry::getPassRegistry());
309  }
310
311private:
312  // Initialize class variables.
313  void initialize(MachineFunction &MFParm);
314
315  bool eliminateRedundantMov(void);
316
317public:
318
319  // Main entry point for this pass.
320  bool runOnMachineFunction(MachineFunction &MF) override {
321    if (skipFunction(MF.getFunction()))
322      return false;
323
324    initialize(MF);
325
326    return eliminateRedundantMov();
327  }
328};
329
330// Initialize class variables.
331void BPFMIPreEmitPeephole::initialize(MachineFunction &MFParm) {
332  MF = &MFParm;
333  TRI = MF->getSubtarget<BPFSubtarget>().getRegisterInfo();
334  LLVM_DEBUG(dbgs() << "*** BPF PreEmit peephole pass ***\n\n");
335}
336
337bool BPFMIPreEmitPeephole::eliminateRedundantMov(void) {
338  MachineInstr* ToErase = nullptr;
339  bool Eliminated = false;
340
341  for (MachineBasicBlock &MBB : *MF) {
342    for (MachineInstr &MI : MBB) {
343      // If the previous instruction was marked for elimination, remove it now.
344      if (ToErase) {
345        LLVM_DEBUG(dbgs() << "  Redundant Mov Eliminated:");
346        LLVM_DEBUG(ToErase->dump());
347        ToErase->eraseFromParent();
348        ToErase = nullptr;
349      }
350
351      // Eliminate identical move:
352      //
353      //   MOV rA, rA
354      //
355      // Note that we cannot remove
356      //   MOV_32_64  rA, wA
357      //   MOV_rr_32  wA, wA
358      // as these two instructions having side effects, zeroing out
359      // top 32 bits of rA.
360      unsigned Opcode = MI.getOpcode();
361      if (Opcode == BPF::MOV_rr) {
362        Register dst = MI.getOperand(0).getReg();
363        Register src = MI.getOperand(1).getReg();
364
365        if (dst != src)
366          continue;
367
368        ToErase = &MI;
369        RedundantMovElemNum++;
370        Eliminated = true;
371      }
372    }
373  }
374
375  return Eliminated;
376}
377
378} // end default namespace
379
380INITIALIZE_PASS(BPFMIPreEmitPeephole, "bpf-mi-pemit-peephole",
381                "BPF PreEmit Peephole Optimization", false, false)
382
383char BPFMIPreEmitPeephole::ID = 0;
384FunctionPass* llvm::createBPFMIPreEmitPeepholePass()
385{
386  return new BPFMIPreEmitPeephole();
387}
388
389STATISTIC(TruncElemNum, "Number of truncation eliminated");
390
391namespace {
392
393struct BPFMIPeepholeTruncElim : public MachineFunctionPass {
394
395  static char ID;
396  const BPFInstrInfo *TII;
397  MachineFunction *MF;
398  MachineRegisterInfo *MRI;
399
400  BPFMIPeepholeTruncElim() : MachineFunctionPass(ID) {
401    initializeBPFMIPeepholeTruncElimPass(*PassRegistry::getPassRegistry());
402  }
403
404private:
405  // Initialize class variables.
406  void initialize(MachineFunction &MFParm);
407
408  bool eliminateTruncSeq(void);
409
410public:
411
412  // Main entry point for this pass.
413  bool runOnMachineFunction(MachineFunction &MF) override {
414    if (skipFunction(MF.getFunction()))
415      return false;
416
417    initialize(MF);
418
419    return eliminateTruncSeq();
420  }
421};
422
423static bool TruncSizeCompatible(int TruncSize, unsigned opcode)
424{
425  if (TruncSize == 1)
426    return opcode == BPF::LDB || opcode == BPF::LDB32;
427
428  if (TruncSize == 2)
429    return opcode == BPF::LDH || opcode == BPF::LDH32;
430
431  if (TruncSize == 4)
432    return opcode == BPF::LDW || opcode == BPF::LDW32;
433
434  return false;
435}
436
437// Initialize class variables.
438void BPFMIPeepholeTruncElim::initialize(MachineFunction &MFParm) {
439  MF = &MFParm;
440  MRI = &MF->getRegInfo();
441  TII = MF->getSubtarget<BPFSubtarget>().getInstrInfo();
442  LLVM_DEBUG(dbgs() << "*** BPF MachineSSA TRUNC Elim peephole pass ***\n\n");
443}
444
445// Reg truncating is often the result of 8/16/32bit->64bit or
446// 8/16bit->32bit conversion. If the reg value is loaded with
447// masked byte width, the AND operation can be removed since
448// BPF LOAD already has zero extension.
449//
450// This also solved a correctness issue.
451// In BPF socket-related program, e.g., __sk_buff->{data, data_end}
452// are 32-bit registers, but later on, kernel verifier will rewrite
453// it with 64-bit value. Therefore, truncating the value after the
454// load will result in incorrect code.
455bool BPFMIPeepholeTruncElim::eliminateTruncSeq(void) {
456  MachineInstr* ToErase = nullptr;
457  bool Eliminated = false;
458
459  for (MachineBasicBlock &MBB : *MF) {
460    for (MachineInstr &MI : MBB) {
461      // The second insn to remove if the eliminate candidate is a pair.
462      MachineInstr *MI2 = nullptr;
463      Register DstReg, SrcReg;
464      MachineInstr *DefMI;
465      int TruncSize = -1;
466
467      // If the previous instruction was marked for elimination, remove it now.
468      if (ToErase) {
469        ToErase->eraseFromParent();
470        ToErase = nullptr;
471      }
472
473      // AND A, 0xFFFFFFFF will be turned into SLL/SRL pair due to immediate
474      // for BPF ANDI is i32, and this case only happens on ALU64.
475      if (MI.getOpcode() == BPF::SRL_ri &&
476          MI.getOperand(2).getImm() == 32) {
477        SrcReg = MI.getOperand(1).getReg();
478        MI2 = MRI->getVRegDef(SrcReg);
479        DstReg = MI.getOperand(0).getReg();
480
481        if (!MI2 ||
482            MI2->getOpcode() != BPF::SLL_ri ||
483            MI2->getOperand(2).getImm() != 32)
484          continue;
485
486        // Update SrcReg.
487        SrcReg = MI2->getOperand(1).getReg();
488        DefMI = MRI->getVRegDef(SrcReg);
489        if (DefMI)
490          TruncSize = 4;
491      } else if (MI.getOpcode() == BPF::AND_ri ||
492                 MI.getOpcode() == BPF::AND_ri_32) {
493        SrcReg = MI.getOperand(1).getReg();
494        DstReg = MI.getOperand(0).getReg();
495        DefMI = MRI->getVRegDef(SrcReg);
496
497        if (!DefMI)
498          continue;
499
500        int64_t imm = MI.getOperand(2).getImm();
501        if (imm == 0xff)
502          TruncSize = 1;
503        else if (imm == 0xffff)
504          TruncSize = 2;
505      }
506
507      if (TruncSize == -1)
508        continue;
509
510      // The definition is PHI node, check all inputs.
511      if (DefMI->isPHI()) {
512        bool CheckFail = false;
513
514        for (unsigned i = 1, e = DefMI->getNumOperands(); i < e; i += 2) {
515          MachineOperand &opnd = DefMI->getOperand(i);
516          if (!opnd.isReg()) {
517            CheckFail = true;
518            break;
519          }
520
521          MachineInstr *PhiDef = MRI->getVRegDef(opnd.getReg());
522          if (!PhiDef || PhiDef->isPHI() ||
523              !TruncSizeCompatible(TruncSize, PhiDef->getOpcode())) {
524            CheckFail = true;
525            break;
526          }
527        }
528
529        if (CheckFail)
530          continue;
531      } else if (!TruncSizeCompatible(TruncSize, DefMI->getOpcode())) {
532        continue;
533      }
534
535      BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(BPF::MOV_rr), DstReg)
536              .addReg(SrcReg);
537
538      if (MI2)
539        MI2->eraseFromParent();
540
541      // Mark it to ToErase, and erase in the next iteration.
542      ToErase = &MI;
543      TruncElemNum++;
544      Eliminated = true;
545    }
546  }
547
548  return Eliminated;
549}
550
551} // end default namespace
552
553INITIALIZE_PASS(BPFMIPeepholeTruncElim, "bpf-mi-trunc-elim",
554                "BPF MachineSSA Peephole Optimization For TRUNC Eliminate",
555                false, false)
556
557char BPFMIPeepholeTruncElim::ID = 0;
558FunctionPass* llvm::createBPFMIPeepholeTruncElimPass()
559{
560  return new BPFMIPeepholeTruncElim();
561}
562