//===---- X86CondBrFolding.cpp - optimize conditional branches ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file defines a pass that optimizes conditional branches on x86 by
// taking advantage of the three-way conditional code generated by compare
// instructions.
// Currently, it tries to hoist an EQ or NE conditional branch to a dominating
// conditional branch that computes the same EQ/NE condition code. An example:
//  bb_0:
//    cmp %0, 19
//    jg bb_1
//    jmp bb_2
//  bb_1:
//    cmp %0, 40
//    jg bb_3
//    jmp bb_4
//  bb_4:
//    cmp %0, 20
//    je bb_5
//    jmp bb_6
// Here we could combine the two compares in bb_0 and bb_4 and have the
// following code:
//  bb_0:
//    cmp %0, 20
//    jg bb_1
//    jl bb_2
//    jmp bb_5
//  bb_1:
//    cmp %0, 40
//    jg bb_3
//    jmp bb_6
// For the case of %0 == 20 (bb_5), we eliminate two jumps, and the control
// height for bb_6 is also reduced. bb_4 is gone after the optimization.
//
// This code pattern is common, especially in switch-case lowering, where we
// generate a compare against "pivot-1" for the inner nodes of the binary
// search tree.
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/Support/BranchProbability.h"

using namespace llvm;

#define DEBUG_TYPE "x86-condbr-folding"

STATISTIC(NumFixedCondBrs, "Number of x86 condbr folded");

namespace {
class X86CondBrFoldingPass : public MachineFunctionPass {
public:
  X86CondBrFoldingPass() : MachineFunctionPass(ID) {}

  StringRef getPassName() const override { return "X86 CondBr Folding"; }

  bool runOnMachineFunction(MachineFunction &MF) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    MachineFunctionPass::getAnalysisUsage(AU);
    AU.addRequired<MachineBranchProbabilityInfo>();
  }

  static char ID;
};
} // namespace

char X86CondBrFoldingPass::ID = 0;
INITIALIZE_PASS(X86CondBrFoldingPass, "X86CondBrFolding", "X86CondBrFolding",
                false, false)

FunctionPass *llvm::createX86CondBrFolding() {
  return new X86CondBrFoldingPass();
}

namespace {
// A struct that stores the auxiliary information for each MBB.
struct TargetMBBInfo {
  MachineBasicBlock *TBB;
  MachineBasicBlock *FBB;
  MachineInstr *BrInstr;
  MachineInstr *CmpInstr;
  X86::CondCode BranchCode;
  unsigned SrcReg;
  int CmpValue;
  bool Modified;
  bool CmpBrOnly;
};
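
// For illustration only (a hypothetical source-level sketch, not part of the
// pass): the kind of C switch whose binary-search lowering tends to produce
// the compare chain shown in the file header. The inner tree nodes compare
// against "pivot - 1" with JG, and the leaves compare for equality, so an
// outer JG on (pivot - 1) dominates an inner JE on the pivot value itself.
// The case values and constants are illustrative; the exact tree shape
// depends on the lowering heuristics.
//
//   int classify(int x) {
//     switch (x) {
//     case 10: return 1;
//     case 20: return 2;   // leaf "cmp x, 20; je" below a "cmp x, 19; jg"
//     case 41: return 3;
//     case 62: return 4;
//     default: return 0;
//     }
//   }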

// A class that optimizes conditional branches by hoisting and merging the
// CondCode computed along a chain of compare-and-branch blocks.
class X86CondBrFolding {
public:
  X86CondBrFolding(const X86InstrInfo *TII,
                   const MachineBranchProbabilityInfo *MBPI,
                   MachineFunction &MF)
      : TII(TII), MBPI(MBPI), MF(MF) {}
  bool optimize();

private:
  const X86InstrInfo *TII;
  const MachineBranchProbabilityInfo *MBPI;
  MachineFunction &MF;
  std::vector<std::unique_ptr<TargetMBBInfo>> MBBInfos;
  SmallVector<MachineBasicBlock *, 4> RemoveList;

  void optimizeCondBr(MachineBasicBlock &MBB,
                      SmallVectorImpl<MachineBasicBlock *> &BranchPath);
  void replaceBrDest(MachineBasicBlock *MBB, MachineBasicBlock *OrigDest,
                     MachineBasicBlock *NewDest);
  void fixupModifiedCond(MachineBasicBlock *MBB);
  std::unique_ptr<TargetMBBInfo> analyzeMBB(MachineBasicBlock &MBB);
  static bool analyzeCompare(const MachineInstr &MI, unsigned &SrcReg,
                             int &CmpValue);
  bool findPath(MachineBasicBlock *MBB,
                SmallVectorImpl<MachineBasicBlock *> &BranchPath);
  TargetMBBInfo *getMBBInfo(MachineBasicBlock *MBB) const {
    return MBBInfos[MBB->getNumber()].get();
  }
};
} // namespace

// Find a valid path along which we can reuse the CondCode.
// The resulting path (if true is returned) is stored in BranchPath.
// Return value:
//  false: no valid path is found.
//  true: a valid path is found and the targetBB can be reached.
bool X86CondBrFolding::findPath(
    MachineBasicBlock *MBB, SmallVectorImpl<MachineBasicBlock *> &BranchPath) {
  TargetMBBInfo *MBBInfo = getMBBInfo(MBB);
  assert(MBBInfo && "Expecting a candidate MBB");
  int CmpValue = MBBInfo->CmpValue;

  MachineBasicBlock *PredMBB = *MBB->pred_begin();
  MachineBasicBlock *SaveMBB = MBB;
  while (PredMBB) {
    TargetMBBInfo *PredMBBInfo = getMBBInfo(PredMBB);
    if (!PredMBBInfo || PredMBBInfo->SrcReg != MBBInfo->SrcReg)
      return false;

    assert(SaveMBB == PredMBBInfo->TBB || SaveMBB == PredMBBInfo->FBB);
    bool IsFalseBranch = (SaveMBB == PredMBBInfo->FBB);

    X86::CondCode CC = PredMBBInfo->BranchCode;
    assert(CC == X86::COND_L || CC == X86::COND_G || CC == X86::COND_E);
    int PredCmpValue = PredMBBInfo->CmpValue;
    bool ValueCmpTrue = ((CmpValue < PredCmpValue && CC == X86::COND_L) ||
                         (CmpValue > PredCmpValue && CC == X86::COND_G) ||
                         (CmpValue == PredCmpValue && CC == X86::COND_E));
    // Check that the result of the value compare and the branch target match.
    if (!(ValueCmpTrue ^ IsFalseBranch)) {
      LLVM_DEBUG(dbgs() << "Dead BB detected!\n");
      return false;
    }

    BranchPath.push_back(PredMBB);
    // These are the conditions on which we could combine the compares.
    if ((CmpValue == PredCmpValue) ||
        (CmpValue == PredCmpValue - 1 && CC == X86::COND_L) ||
        (CmpValue == PredCmpValue + 1 && CC == X86::COND_G))
      return true;

    // If PredMBB has more than one predecessor, or is not a pure cmp-and-br
    // block, bail out.
    if (PredMBB->pred_size() != 1 || !PredMBBInfo->CmpBrOnly)
      return false;

    SaveMBB = PredMBB;
    PredMBB = *PredMBB->pred_begin();
  }
  return false;
}

// Fix up any PHI node in the successor of MBB.
static void fixPHIsInSucc(MachineBasicBlock *MBB, MachineBasicBlock *OldMBB,
                          MachineBasicBlock *NewMBB) {
  if (NewMBB == OldMBB)
    return;
  for (auto MI = MBB->instr_begin(), ME = MBB->instr_end();
       MI != ME && MI->isPHI(); ++MI)
    for (unsigned i = 2, e = MI->getNumOperands() + 1; i != e; i += 2) {
      MachineOperand &MO = MI->getOperand(i);
      if (MO.getMBB() == OldMBB)
        MO.setMBB(NewMBB);
    }
}

// Utility function to set branch probability for the edge MBB->SuccMBB.
static inline bool setBranchProb(MachineBasicBlock *MBB,
                                 MachineBasicBlock *SuccMBB,
                                 BranchProbability Prob) {
  auto MBBI = std::find(MBB->succ_begin(), MBB->succ_end(), SuccMBB);
  if (MBBI == MBB->succ_end())
    return false;
  MBB->setSuccProbability(MBBI, Prob);
  return true;
}

// Utility function to find the unconditional br instruction in MBB.
static inline MachineBasicBlock::iterator
findUncondBrI(MachineBasicBlock *MBB) {
  return std::find_if(MBB->begin(), MBB->end(), [](MachineInstr &MI) -> bool {
    return MI.getOpcode() == X86::JMP_1;
  });
}

// Replace MBB's original successor, OrigDest, with NewDest.
// Also update the MBBInfo for MBB.
void X86CondBrFolding::replaceBrDest(MachineBasicBlock *MBB,
                                     MachineBasicBlock *OrigDest,
                                     MachineBasicBlock *NewDest) {
  TargetMBBInfo *MBBInfo = getMBBInfo(MBB);
  MachineInstr *BrMI;
  if (MBBInfo->TBB == OrigDest) {
    BrMI = MBBInfo->BrInstr;
    MachineInstrBuilder MIB =
        BuildMI(*MBB, BrMI, MBB->findDebugLoc(BrMI), TII->get(X86::JCC_1))
            .addMBB(NewDest)
            .addImm(MBBInfo->BranchCode);
    MBBInfo->TBB = NewDest;
    MBBInfo->BrInstr = MIB.getInstr();
  } else {
    // Should be the unconditional jump stmt.
    MachineBasicBlock::iterator UncondBrI = findUncondBrI(MBB);
    BuildMI(*MBB, UncondBrI, MBB->findDebugLoc(UncondBrI), TII->get(X86::JMP_1))
        .addMBB(NewDest);
    MBBInfo->FBB = NewDest;
    BrMI = &*UncondBrI;
  }
  fixPHIsInSucc(NewDest, OrigDest, MBB);
  BrMI->eraseFromParent();
  MBB->addSuccessor(NewDest);
  setBranchProb(MBB, NewDest, MBPI->getEdgeProbability(MBB, OrigDest));
  MBB->removeSuccessor(OrigDest);
}

// Change the CondCode and BrInstr according to MBBInfo.
void X86CondBrFolding::fixupModifiedCond(MachineBasicBlock *MBB) {
  TargetMBBInfo *MBBInfo = getMBBInfo(MBB);
  if (!MBBInfo->Modified)
    return;

  MachineInstr *BrMI = MBBInfo->BrInstr;
  X86::CondCode CC = MBBInfo->BranchCode;
  MachineInstrBuilder MIB =
      BuildMI(*MBB, BrMI, MBB->findDebugLoc(BrMI), TII->get(X86::JCC_1))
          .addMBB(MBBInfo->TBB)
          .addImm(CC);
  BrMI->eraseFromParent();
  MBBInfo->BrInstr = MIB.getInstr();

  MachineBasicBlock::iterator UncondBrI = findUncondBrI(MBB);
  BuildMI(*MBB, UncondBrI, MBB->findDebugLoc(UncondBrI), TII->get(X86::JMP_1))
      .addMBB(MBBInfo->FBB);
  MBB->erase(UncondBrI);
  MBBInfo->Modified = false;
}

//
// Apply the transformation:
//  RootMBB -1->   ...    PredMBB -3-> MBB -5-> TargetMBB
//     \-2->            \-4->              \-6-> FalseMBB
// ==>
//  RootMBB -1->   ...    PredMBB -7-> FalseMBB
// TargetMBB <-8-/        \-2->  \-4->
//
// Note that PredMBB and RootMBB could be the same.
// In the case of a dead TargetMBB, we will not have TargetMBB and edge 8.
//
// There is some special handling when RootMBB's branch is COND_E, in which
// case we directly short-cycle the branch instruction.
//
void X86CondBrFolding::optimizeCondBr(
    MachineBasicBlock &MBB, SmallVectorImpl<MachineBasicBlock *> &BranchPath) {
  X86::CondCode CC;
  TargetMBBInfo *MBBInfo = getMBBInfo(&MBB);
  assert(MBBInfo && "Expecting a candidate MBB");
  MachineBasicBlock *TargetMBB = MBBInfo->TBB;
  BranchProbability TargetProb = MBPI->getEdgeProbability(&MBB, MBBInfo->TBB);

  // Forward the jump from MBB's predecessor to MBB's false target.
  MachineBasicBlock *PredMBB = BranchPath.front();
  TargetMBBInfo *PredMBBInfo = getMBBInfo(PredMBB);
  assert(PredMBBInfo && "Expecting a candidate MBB");
  if (PredMBBInfo->Modified)
    fixupModifiedCond(PredMBB);
  CC = PredMBBInfo->BranchCode;
  // Don't do this if the depth of BranchPath is 1 and PredMBB is of COND_E.
  // We will short-cycle directly for this case.
  if (!(CC == X86::COND_E && BranchPath.size() == 1))
    replaceBrDest(PredMBB, &MBB, MBBInfo->FBB);

  MachineBasicBlock *RootMBB = BranchPath.back();
  TargetMBBInfo *RootMBBInfo = getMBBInfo(RootMBB);
  assert(RootMBBInfo && "Expecting a candidate MBB");
  if (RootMBBInfo->Modified)
    fixupModifiedCond(RootMBB);
  CC = RootMBBInfo->BranchCode;

  if (CC != X86::COND_E) {
    MachineBasicBlock::iterator UncondBrI = findUncondBrI(RootMBB);
    // RootMBB: Cond jump to the original not-taken MBB.
    X86::CondCode NewCC;
    switch (CC) {
    case X86::COND_L:
      NewCC = X86::COND_G;
      break;
    case X86::COND_G:
      NewCC = X86::COND_L;
      break;
    default:
      llvm_unreachable("unexpected conditional code.");
    }
    BuildMI(*RootMBB, UncondBrI, RootMBB->findDebugLoc(UncondBrI),
            TII->get(X86::JCC_1))
        .addMBB(RootMBBInfo->FBB)
        .addImm(NewCC);

    // RootMBB: Jump to TargetMBB.
    BuildMI(*RootMBB, UncondBrI, RootMBB->findDebugLoc(UncondBrI),
            TII->get(X86::JMP_1))
        .addMBB(TargetMBB);
    RootMBB->addSuccessor(TargetMBB);
    fixPHIsInSucc(TargetMBB, &MBB, RootMBB);
    RootMBB->erase(UncondBrI);
  } else {
    replaceBrDest(RootMBB, RootMBBInfo->TBB, TargetMBB);
  }

  // Fix RootMBB's CmpValue so that it matches MBB's CmpValue for TargetMBB.
  // Don't set the immediate directly; move MBB's compare instruction here
  // instead, as the opcode might be different.
  if (RootMBBInfo->CmpValue != MBBInfo->CmpValue) {
    MachineInstr *NewCmp = MBBInfo->CmpInstr;
    NewCmp->removeFromParent();
    RootMBB->insert(RootMBBInfo->CmpInstr, NewCmp);
    RootMBBInfo->CmpInstr->eraseFromParent();
  }

  // Fix branch probabilities.
  auto fixBranchProb = [&](MachineBasicBlock *NextMBB) {
    BranchProbability Prob;
    for (auto &I : BranchPath) {
      MachineBasicBlock *ThisMBB = I;
      if (!ThisMBB->hasSuccessorProbabilities() ||
          !ThisMBB->isSuccessor(NextMBB))
        break;
      Prob = MBPI->getEdgeProbability(ThisMBB, NextMBB);
      if (Prob.isUnknown())
        break;
      TargetProb = Prob * TargetProb;
      Prob = Prob - TargetProb;
      setBranchProb(ThisMBB, NextMBB, Prob);
      if (ThisMBB == RootMBB) {
        setBranchProb(ThisMBB, TargetMBB, TargetProb);
      }
      ThisMBB->normalizeSuccProbs();
      if (ThisMBB == RootMBB)
        break;
      NextMBB = ThisMBB;
    }
    return true;
  };
  if (CC != X86::COND_E && !TargetProb.isUnknown())
    fixBranchProb(MBBInfo->FBB);

  if (CC != X86::COND_E)
    RemoveList.push_back(&MBB);

  // Invalidate MBBInfo just in case.
  MBBInfos[MBB.getNumber()] = nullptr;
  MBBInfos[RootMBB->getNumber()] = nullptr;

  LLVM_DEBUG(dbgs() << "After optimization:\nRootMBB is: " << *RootMBB << "\n");
  if (BranchPath.size() > 1)
    LLVM_DEBUG(dbgs() << "PredMBB is: " << *(BranchPath[0]) << "\n");
}

// Driver function for the optimization: find the valid candidates and apply
// the transformation.
bool X86CondBrFolding::optimize() {
  bool Changed = false;
  LLVM_DEBUG(dbgs() << "***** X86CondBr Folding on Function: " << MF.getName()
                    << " *****\n");
  // Set up the data structures.
  MBBInfos.resize(MF.getNumBlockIDs());
  for (auto &MBB : MF)
    MBBInfos[MBB.getNumber()] = analyzeMBB(MBB);

  for (auto &MBB : MF) {
    TargetMBBInfo *MBBInfo = getMBBInfo(&MBB);
    if (!MBBInfo || !MBBInfo->CmpBrOnly)
      continue;
    if (MBB.pred_size() != 1)
      continue;
    LLVM_DEBUG(dbgs() << "Work on MBB." << MBB.getNumber()
                      << " CmpValue: " << MBBInfo->CmpValue << "\n");
    SmallVector<MachineBasicBlock *, 4> BranchPath;
    if (!findPath(&MBB, BranchPath))
      continue;

#ifndef NDEBUG
    LLVM_DEBUG(dbgs() << "Found one path (len=" << BranchPath.size() << "):\n");
    int Index = 1;
    LLVM_DEBUG(dbgs() << "Target MBB is: " << MBB << "\n");
    for (auto I = BranchPath.rbegin(); I != BranchPath.rend(); ++I, ++Index) {
      MachineBasicBlock *PMBB = *I;
      TargetMBBInfo *PMBBInfo = getMBBInfo(PMBB);
      LLVM_DEBUG(dbgs() << "Path MBB (" << Index << " of " << BranchPath.size()
                        << ") is " << *PMBB);
      LLVM_DEBUG(dbgs() << "CC=" << PMBBInfo->BranchCode
                        << " Val=" << PMBBInfo->CmpValue
                        << " CmpBrOnly=" << PMBBInfo->CmpBrOnly << "\n\n");
    }
#endif

    optimizeCondBr(MBB, BranchPath);
    Changed = true;
  }
  NumFixedCondBrs += RemoveList.size();
  for (auto MBBI : RemoveList) {
    while (!MBBI->succ_empty())
      MBBI->removeSuccessor(MBBI->succ_end() - 1);

    MBBI->eraseFromParent();
  }

  return Changed;
}

// Analyze instructions that generate the CondCode and extract information.
bool X86CondBrFolding::analyzeCompare(const MachineInstr &MI, unsigned &SrcReg,
                                      int &CmpValue) {
  unsigned SrcRegIndex = 0;
  unsigned ValueIndex = 0;
  switch (MI.getOpcode()) {
  // TODO: handle test instructions.
  default:
    return false;
  case X86::CMP64ri32:
  case X86::CMP64ri8:
  case X86::CMP32ri:
  case X86::CMP32ri8:
  case X86::CMP16ri:
  case X86::CMP16ri8:
  case X86::CMP8ri:
    SrcRegIndex = 0;
    ValueIndex = 1;
    break;
  case X86::SUB64ri32:
  case X86::SUB64ri8:
  case X86::SUB32ri:
  case X86::SUB32ri8:
  case X86::SUB16ri:
  case X86::SUB16ri8:
  case X86::SUB8ri:
    SrcRegIndex = 1;
    ValueIndex = 2;
    break;
  }
  SrcReg = MI.getOperand(SrcRegIndex).getReg();
  if (!MI.getOperand(ValueIndex).isImm())
    return false;
  CmpValue = MI.getOperand(ValueIndex).getImm();
  return true;
}

// Analyze a candidate MBB and extract all the information needed.
// A valid candidate has exactly two successors.
// It should also end with the sequence:
//  compare instruction,
//  conditional branch,
//  unconditional branch.
// Return a TargetMBBInfo if MBB is a valid candidate and nullptr otherwise.
std::unique_ptr<TargetMBBInfo>
X86CondBrFolding::analyzeMBB(MachineBasicBlock &MBB) {
  MachineBasicBlock *TBB;
  MachineBasicBlock *FBB;
  MachineInstr *BrInstr;
  MachineInstr *CmpInstr;
  X86::CondCode CC;
  unsigned SrcReg;
  int CmpValue;
  bool Modified;
  bool CmpBrOnly;

  if (MBB.succ_size() != 2)
    return nullptr;

  CmpBrOnly = true;
  FBB = TBB = nullptr;
  CmpInstr = nullptr;
  MachineBasicBlock::iterator I = MBB.end();
  while (I != MBB.begin()) {
    --I;
    if (I->isDebugValue())
      continue;
    if (I->getOpcode() == X86::JMP_1) {
      if (FBB)
        return nullptr;
      FBB = I->getOperand(0).getMBB();
      continue;
    }
    if (I->isBranch()) {
      if (TBB)
        return nullptr;
      CC = X86::getCondFromBranch(*I);
      switch (CC) {
      default:
        return nullptr;
      case X86::COND_E:
      case X86::COND_L:
      case X86::COND_G:
      case X86::COND_NE:
      case X86::COND_LE:
      case X86::COND_GE:
        break;
      }
      TBB = I->getOperand(0).getMBB();
      BrInstr = &*I;
      continue;
    }
    if (analyzeCompare(*I, SrcReg, CmpValue)) {
      if (CmpInstr)
        return nullptr;
      CmpInstr = &*I;
      continue;
    }
    CmpBrOnly = false;
    break;
  }

  if (!TBB || !FBB || !CmpInstr)
    return nullptr;

  // Simplify the CondCode. Note this is only to simplify the findPath logic
  // and will not change the instruction here.
  switch (CC) {
  case X86::COND_NE:
    CC = X86::COND_E;
    std::swap(TBB, FBB);
    Modified = true;
    break;
  case X86::COND_LE:
    if (CmpValue == INT_MAX)
      return nullptr;
    CC = X86::COND_L;
    CmpValue += 1;
    Modified = true;
    break;
  case X86::COND_GE:
    if (CmpValue == INT_MIN)
      return nullptr;
    CC = X86::COND_G;
    CmpValue -= 1;
    Modified = true;
    break;
  default:
    Modified = false;
    break;
  }
  return std::make_unique<TargetMBBInfo>(TargetMBBInfo{
      TBB, FBB, BrInstr, CmpInstr, CC, SrcReg, CmpValue, Modified, CmpBrOnly});
}

bool X86CondBrFoldingPass::runOnMachineFunction(MachineFunction &MF) {
  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  if (!ST.threewayBranchProfitable())
    return false;
  const X86InstrInfo *TII = ST.getInstrInfo();
  const MachineBranchProbabilityInfo *MBPI =
      &getAnalysis<MachineBranchProbabilityInfo>();

  X86CondBrFolding CondBr(TII, MBPI, MF);
  return CondBr.optimize();
}