1320374Sdim//===-- AArch64CondBrTuning.cpp --- Conditional branch tuning for AArch64 -===// 2320374Sdim// 3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4353358Sdim// See https://llvm.org/LICENSE.txt for license information. 5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6320374Sdim// 7320374Sdim//===----------------------------------------------------------------------===// 8320374Sdim/// \file 9320374Sdim/// This file contains a pass that transforms CBZ/CBNZ/TBZ/TBNZ instructions 10320374Sdim/// into a conditional branch (B.cond), when the NZCV flags can be set for 11320374Sdim/// "free". This is preferred on targets that have more flexibility when 12320374Sdim/// scheduling B.cond instructions as compared to CBZ/CBNZ/TBZ/TBNZ (assuming 13320374Sdim/// all other variables are equal). This can also reduce register pressure. 14320374Sdim/// 15320374Sdim/// A few examples: 16320374Sdim/// 17320374Sdim/// 1) add w8, w0, w1 -> cmn w0, w1 ; CMN is an alias of ADDS. 18320374Sdim/// cbz w8, .LBB_2 -> b.eq .LBB0_2 19320374Sdim/// 20320374Sdim/// 2) add w8, w0, w1 -> adds w8, w0, w1 ; w8 has multiple uses. 21320374Sdim/// cbz w8, .LBB1_2 -> b.eq .LBB1_2 22320374Sdim/// 23320374Sdim/// 3) sub w8, w0, w1 -> subs w8, w0, w1 ; w8 has multiple uses. 24320572Sdim/// tbz w8, #31, .LBB6_2 -> b.pl .LBB6_2 25320374Sdim/// 26320374Sdim//===----------------------------------------------------------------------===// 27320374Sdim 28320374Sdim#include "AArch64.h" 29320374Sdim#include "AArch64Subtarget.h" 30320374Sdim#include "llvm/CodeGen/MachineFunction.h" 31320374Sdim#include "llvm/CodeGen/MachineFunctionPass.h" 32320374Sdim#include "llvm/CodeGen/MachineInstrBuilder.h" 33320374Sdim#include "llvm/CodeGen/MachineRegisterInfo.h" 34320374Sdim#include "llvm/CodeGen/Passes.h" 35327952Sdim#include "llvm/CodeGen/TargetInstrInfo.h" 36327952Sdim#include "llvm/CodeGen/TargetRegisterInfo.h" 37327952Sdim#include "llvm/CodeGen/TargetSubtargetInfo.h" 38320374Sdim#include "llvm/Support/Debug.h" 39320374Sdim#include "llvm/Support/raw_ostream.h" 40320374Sdim 41320374Sdimusing namespace llvm; 42320374Sdim 43320374Sdim#define DEBUG_TYPE "aarch64-cond-br-tuning" 44320374Sdim#define AARCH64_CONDBR_TUNING_NAME "AArch64 Conditional Branch Tuning" 45320374Sdim 46320374Sdimnamespace { 47320374Sdimclass AArch64CondBrTuning : public MachineFunctionPass { 48320374Sdim const AArch64InstrInfo *TII; 49320374Sdim const TargetRegisterInfo *TRI; 50320374Sdim 51320374Sdim MachineRegisterInfo *MRI; 52320374Sdim 53320374Sdimpublic: 54320374Sdim static char ID; 55320374Sdim AArch64CondBrTuning() : MachineFunctionPass(ID) { 56320374Sdim initializeAArch64CondBrTuningPass(*PassRegistry::getPassRegistry()); 57320374Sdim } 58320374Sdim void getAnalysisUsage(AnalysisUsage &AU) const override; 59320374Sdim bool runOnMachineFunction(MachineFunction &MF) override; 60320374Sdim StringRef getPassName() const override { return AARCH64_CONDBR_TUNING_NAME; } 61320374Sdim 62320374Sdimprivate: 63320374Sdim MachineInstr *getOperandDef(const MachineOperand &MO); 64320374Sdim MachineInstr *convertToFlagSetting(MachineInstr &MI, bool IsFlagSetting); 65320374Sdim MachineInstr *convertToCondBr(MachineInstr &MI); 66320374Sdim bool tryToTuneBranch(MachineInstr &MI, MachineInstr &DefMI); 67320374Sdim}; 68320374Sdim} // end anonymous namespace 69320374Sdim 70320374Sdimchar AArch64CondBrTuning::ID = 0; 71320374Sdim 72320374SdimINITIALIZE_PASS(AArch64CondBrTuning, "aarch64-cond-br-tuning", 73320374Sdim AARCH64_CONDBR_TUNING_NAME, false, false) 74320374Sdim 75320374Sdimvoid AArch64CondBrTuning::getAnalysisUsage(AnalysisUsage &AU) const { 76320374Sdim AU.setPreservesCFG(); 77320374Sdim MachineFunctionPass::getAnalysisUsage(AU); 78320374Sdim} 79320374Sdim 80320374SdimMachineInstr *AArch64CondBrTuning::getOperandDef(const MachineOperand &MO) { 81360784Sdim if (!Register::isVirtualRegister(MO.getReg())) 82320374Sdim return nullptr; 83320374Sdim return MRI->getUniqueVRegDef(MO.getReg()); 84320374Sdim} 85320374Sdim 86320374SdimMachineInstr *AArch64CondBrTuning::convertToFlagSetting(MachineInstr &MI, 87320374Sdim bool IsFlagSetting) { 88320374Sdim // If this is already the flag setting version of the instruction (e.g., SUBS) 89320374Sdim // just make sure the implicit-def of NZCV isn't marked dead. 90320374Sdim if (IsFlagSetting) { 91320374Sdim for (unsigned I = MI.getNumExplicitOperands(), E = MI.getNumOperands(); 92320374Sdim I != E; ++I) { 93320374Sdim MachineOperand &MO = MI.getOperand(I); 94320374Sdim if (MO.isReg() && MO.isDead() && MO.getReg() == AArch64::NZCV) 95320374Sdim MO.setIsDead(false); 96320374Sdim } 97320374Sdim return &MI; 98320374Sdim } 99320374Sdim bool Is64Bit; 100320374Sdim unsigned NewOpc = TII->convertToFlagSettingOpc(MI.getOpcode(), Is64Bit); 101360784Sdim Register NewDestReg = MI.getOperand(0).getReg(); 102320374Sdim if (MRI->hasOneNonDBGUse(MI.getOperand(0).getReg())) 103320374Sdim NewDestReg = Is64Bit ? AArch64::XZR : AArch64::WZR; 104320374Sdim 105320374Sdim MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), 106320374Sdim TII->get(NewOpc), NewDestReg); 107320374Sdim for (unsigned I = 1, E = MI.getNumOperands(); I != E; ++I) 108320374Sdim MIB.add(MI.getOperand(I)); 109320374Sdim 110320374Sdim return MIB; 111320374Sdim} 112320374Sdim 113320374SdimMachineInstr *AArch64CondBrTuning::convertToCondBr(MachineInstr &MI) { 114320374Sdim AArch64CC::CondCode CC; 115320374Sdim MachineBasicBlock *TargetMBB = TII->getBranchDestBlock(MI); 116320374Sdim switch (MI.getOpcode()) { 117320374Sdim default: 118320374Sdim llvm_unreachable("Unexpected opcode!"); 119320374Sdim 120320374Sdim case AArch64::CBZW: 121320374Sdim case AArch64::CBZX: 122320374Sdim CC = AArch64CC::EQ; 123320374Sdim break; 124320374Sdim case AArch64::CBNZW: 125320374Sdim case AArch64::CBNZX: 126320374Sdim CC = AArch64CC::NE; 127320374Sdim break; 128320374Sdim case AArch64::TBZW: 129320374Sdim case AArch64::TBZX: 130320572Sdim CC = AArch64CC::PL; 131320374Sdim break; 132320374Sdim case AArch64::TBNZW: 133320374Sdim case AArch64::TBNZX: 134320572Sdim CC = AArch64CC::MI; 135320374Sdim break; 136320374Sdim } 137320374Sdim return BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), TII->get(AArch64::Bcc)) 138320374Sdim .addImm(CC) 139320374Sdim .addMBB(TargetMBB); 140320374Sdim} 141320374Sdim 142320374Sdimbool AArch64CondBrTuning::tryToTuneBranch(MachineInstr &MI, 143320374Sdim MachineInstr &DefMI) { 144320374Sdim // We don't want NZCV bits live across blocks. 145320374Sdim if (MI.getParent() != DefMI.getParent()) 146320374Sdim return false; 147320374Sdim 148320374Sdim bool IsFlagSetting = true; 149320374Sdim unsigned MIOpc = MI.getOpcode(); 150320374Sdim MachineInstr *NewCmp = nullptr, *NewBr = nullptr; 151320374Sdim switch (DefMI.getOpcode()) { 152320374Sdim default: 153320374Sdim return false; 154320374Sdim case AArch64::ADDWri: 155320374Sdim case AArch64::ADDWrr: 156320374Sdim case AArch64::ADDWrs: 157320374Sdim case AArch64::ADDWrx: 158320374Sdim case AArch64::ANDWri: 159320374Sdim case AArch64::ANDWrr: 160320374Sdim case AArch64::ANDWrs: 161320374Sdim case AArch64::BICWrr: 162320374Sdim case AArch64::BICWrs: 163320374Sdim case AArch64::SUBWri: 164320374Sdim case AArch64::SUBWrr: 165320374Sdim case AArch64::SUBWrs: 166320374Sdim case AArch64::SUBWrx: 167320374Sdim IsFlagSetting = false; 168320970Sdim LLVM_FALLTHROUGH; 169320374Sdim case AArch64::ADDSWri: 170320374Sdim case AArch64::ADDSWrr: 171320374Sdim case AArch64::ADDSWrs: 172320374Sdim case AArch64::ADDSWrx: 173320374Sdim case AArch64::ANDSWri: 174320374Sdim case AArch64::ANDSWrr: 175320374Sdim case AArch64::ANDSWrs: 176320374Sdim case AArch64::BICSWrr: 177320374Sdim case AArch64::BICSWrs: 178320374Sdim case AArch64::SUBSWri: 179320374Sdim case AArch64::SUBSWrr: 180320374Sdim case AArch64::SUBSWrs: 181320374Sdim case AArch64::SUBSWrx: 182320374Sdim switch (MIOpc) { 183320374Sdim default: 184320374Sdim llvm_unreachable("Unexpected opcode!"); 185320374Sdim 186320374Sdim case AArch64::CBZW: 187320374Sdim case AArch64::CBNZW: 188320374Sdim case AArch64::TBZW: 189320374Sdim case AArch64::TBNZW: 190320374Sdim // Check to see if the TBZ/TBNZ is checking the sign bit. 191320374Sdim if ((MIOpc == AArch64::TBZW || MIOpc == AArch64::TBNZW) && 192320374Sdim MI.getOperand(1).getImm() != 31) 193320374Sdim return false; 194320374Sdim 195320374Sdim // There must not be any instruction between DefMI and MI that clobbers or 196320374Sdim // reads NZCV. 197320374Sdim MachineBasicBlock::iterator I(DefMI), E(MI); 198320374Sdim for (I = std::next(I); I != E; ++I) { 199320374Sdim if (I->modifiesRegister(AArch64::NZCV, TRI) || 200320374Sdim I->readsRegister(AArch64::NZCV, TRI)) 201320374Sdim return false; 202320374Sdim } 203341825Sdim LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); 204341825Sdim LLVM_DEBUG(DefMI.print(dbgs())); 205341825Sdim LLVM_DEBUG(dbgs() << " "); 206341825Sdim LLVM_DEBUG(MI.print(dbgs())); 207320374Sdim 208320374Sdim NewCmp = convertToFlagSetting(DefMI, IsFlagSetting); 209320374Sdim NewBr = convertToCondBr(MI); 210320374Sdim break; 211320374Sdim } 212320374Sdim break; 213320374Sdim 214320374Sdim case AArch64::ADDXri: 215320374Sdim case AArch64::ADDXrr: 216320374Sdim case AArch64::ADDXrs: 217320374Sdim case AArch64::ADDXrx: 218320374Sdim case AArch64::ANDXri: 219320374Sdim case AArch64::ANDXrr: 220320374Sdim case AArch64::ANDXrs: 221320374Sdim case AArch64::BICXrr: 222320374Sdim case AArch64::BICXrs: 223320374Sdim case AArch64::SUBXri: 224320374Sdim case AArch64::SUBXrr: 225320374Sdim case AArch64::SUBXrs: 226320374Sdim case AArch64::SUBXrx: 227320374Sdim IsFlagSetting = false; 228320970Sdim LLVM_FALLTHROUGH; 229320374Sdim case AArch64::ADDSXri: 230320374Sdim case AArch64::ADDSXrr: 231320374Sdim case AArch64::ADDSXrs: 232320374Sdim case AArch64::ADDSXrx: 233320374Sdim case AArch64::ANDSXri: 234320374Sdim case AArch64::ANDSXrr: 235320374Sdim case AArch64::ANDSXrs: 236320374Sdim case AArch64::BICSXrr: 237320374Sdim case AArch64::BICSXrs: 238320374Sdim case AArch64::SUBSXri: 239320374Sdim case AArch64::SUBSXrr: 240320374Sdim case AArch64::SUBSXrs: 241320374Sdim case AArch64::SUBSXrx: 242320374Sdim switch (MIOpc) { 243320374Sdim default: 244320374Sdim llvm_unreachable("Unexpected opcode!"); 245320374Sdim 246320374Sdim case AArch64::CBZX: 247320374Sdim case AArch64::CBNZX: 248320374Sdim case AArch64::TBZX: 249320374Sdim case AArch64::TBNZX: { 250320374Sdim // Check to see if the TBZ/TBNZ is checking the sign bit. 251320374Sdim if ((MIOpc == AArch64::TBZX || MIOpc == AArch64::TBNZX) && 252320374Sdim MI.getOperand(1).getImm() != 63) 253320374Sdim return false; 254320374Sdim // There must not be any instruction between DefMI and MI that clobbers or 255320374Sdim // reads NZCV. 256320374Sdim MachineBasicBlock::iterator I(DefMI), E(MI); 257320374Sdim for (I = std::next(I); I != E; ++I) { 258320374Sdim if (I->modifiesRegister(AArch64::NZCV, TRI) || 259320374Sdim I->readsRegister(AArch64::NZCV, TRI)) 260320374Sdim return false; 261320374Sdim } 262341825Sdim LLVM_DEBUG(dbgs() << " Replacing instructions:\n "); 263341825Sdim LLVM_DEBUG(DefMI.print(dbgs())); 264341825Sdim LLVM_DEBUG(dbgs() << " "); 265341825Sdim LLVM_DEBUG(MI.print(dbgs())); 266320374Sdim 267320374Sdim NewCmp = convertToFlagSetting(DefMI, IsFlagSetting); 268320374Sdim NewBr = convertToCondBr(MI); 269320374Sdim break; 270320374Sdim } 271320374Sdim } 272320374Sdim break; 273320374Sdim } 274320572Sdim (void)NewCmp; (void)NewBr; 275320374Sdim assert(NewCmp && NewBr && "Expected new instructions."); 276320374Sdim 277341825Sdim LLVM_DEBUG(dbgs() << " with instruction:\n "); 278341825Sdim LLVM_DEBUG(NewCmp->print(dbgs())); 279341825Sdim LLVM_DEBUG(dbgs() << " "); 280341825Sdim LLVM_DEBUG(NewBr->print(dbgs())); 281320374Sdim 282320374Sdim // If this was a flag setting version of the instruction, we use the original 283320374Sdim // instruction by just clearing the dead marked on the implicit-def of NCZV. 284320374Sdim // Therefore, we should not erase this instruction. 285320374Sdim if (!IsFlagSetting) 286320374Sdim DefMI.eraseFromParent(); 287320374Sdim MI.eraseFromParent(); 288320374Sdim return true; 289320374Sdim} 290320374Sdim 291320374Sdimbool AArch64CondBrTuning::runOnMachineFunction(MachineFunction &MF) { 292327952Sdim if (skipFunction(MF.getFunction())) 293320374Sdim return false; 294320374Sdim 295341825Sdim LLVM_DEBUG( 296341825Sdim dbgs() << "********** AArch64 Conditional Branch Tuning **********\n" 297341825Sdim << "********** Function: " << MF.getName() << '\n'); 298320374Sdim 299320374Sdim TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo()); 300320374Sdim TRI = MF.getSubtarget().getRegisterInfo(); 301320374Sdim MRI = &MF.getRegInfo(); 302320374Sdim 303320374Sdim bool Changed = false; 304320374Sdim for (MachineBasicBlock &MBB : MF) { 305320374Sdim bool LocalChange = false; 306320374Sdim for (MachineBasicBlock::iterator I = MBB.getFirstTerminator(), 307320374Sdim E = MBB.end(); 308320374Sdim I != E; ++I) { 309320374Sdim MachineInstr &MI = *I; 310320374Sdim switch (MI.getOpcode()) { 311320374Sdim default: 312320374Sdim break; 313320374Sdim case AArch64::CBZW: 314320374Sdim case AArch64::CBZX: 315320374Sdim case AArch64::CBNZW: 316320374Sdim case AArch64::CBNZX: 317320374Sdim case AArch64::TBZW: 318320374Sdim case AArch64::TBZX: 319320374Sdim case AArch64::TBNZW: 320320374Sdim case AArch64::TBNZX: 321320374Sdim MachineInstr *DefMI = getOperandDef(MI.getOperand(0)); 322320374Sdim LocalChange = (DefMI && tryToTuneBranch(MI, *DefMI)); 323320374Sdim break; 324320374Sdim } 325320374Sdim // If the optimization was successful, we can't optimize any other 326320374Sdim // branches because doing so would clobber the NZCV flags. 327320374Sdim if (LocalChange) { 328320374Sdim Changed = true; 329320374Sdim break; 330320374Sdim } 331320374Sdim } 332320374Sdim } 333320374Sdim return Changed; 334320374Sdim} 335320374Sdim 336320374SdimFunctionPass *llvm::createAArch64CondBrTuning() { 337320374Sdim return new AArch64CondBrTuning(); 338320374Sdim} 339