1//===- AArch64MacroFusion.cpp - AArch64 Macro Fusion ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file This file contains the AArch64 implementation of the DAG scheduling
10///  mutation to pair instructions back to back.
11//
12//===----------------------------------------------------------------------===//
13
14#include "AArch64MacroFusion.h"
15#include "AArch64Subtarget.h"
16#include "llvm/CodeGen/MacroFusion.h"
17#include "llvm/CodeGen/TargetInstrInfo.h"
18
19using namespace llvm;
20
21/// CMN, CMP, TST followed by Bcc
22static bool isArithmeticBccPair(const MachineInstr *FirstMI,
23                                const MachineInstr &SecondMI, bool CmpOnly) {
24  if (SecondMI.getOpcode() != AArch64::Bcc)
25    return false;
26
27  // Assume the 1st instr to be a wildcard if it is unspecified.
28  if (FirstMI == nullptr)
29    return true;
30
31  // If we're in CmpOnly mode, we only fuse arithmetic instructions that
32  // discard their result.
33  if (CmpOnly && FirstMI->getOperand(0).isReg() &&
34      !(FirstMI->getOperand(0).getReg() == AArch64::XZR ||
35        FirstMI->getOperand(0).getReg() == AArch64::WZR)) {
36    return false;
37  }
38
39  switch (FirstMI->getOpcode()) {
40  case AArch64::ADDSWri:
41  case AArch64::ADDSWrr:
42  case AArch64::ADDSXri:
43  case AArch64::ADDSXrr:
44  case AArch64::ANDSWri:
45  case AArch64::ANDSWrr:
46  case AArch64::ANDSXri:
47  case AArch64::ANDSXrr:
48  case AArch64::SUBSWri:
49  case AArch64::SUBSWrr:
50  case AArch64::SUBSXri:
51  case AArch64::SUBSXrr:
52  case AArch64::BICSWrr:
53  case AArch64::BICSXrr:
54    return true;
55  case AArch64::ADDSWrs:
56  case AArch64::ADDSXrs:
57  case AArch64::ANDSWrs:
58  case AArch64::ANDSXrs:
59  case AArch64::SUBSWrs:
60  case AArch64::SUBSXrs:
61  case AArch64::BICSWrs:
62  case AArch64::BICSXrs:
63    // Shift value can be 0 making these behave like the "rr" variant...
64    return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
65  }
66
67  return false;
68}
69
70/// ALU operations followed by CBZ/CBNZ.
71static bool isArithmeticCbzPair(const MachineInstr *FirstMI,
72                                const MachineInstr &SecondMI) {
73  if (SecondMI.getOpcode() != AArch64::CBZW &&
74      SecondMI.getOpcode() != AArch64::CBZX &&
75      SecondMI.getOpcode() != AArch64::CBNZW &&
76      SecondMI.getOpcode() != AArch64::CBNZX)
77    return false;
78
79  // Assume the 1st instr to be a wildcard if it is unspecified.
80  if (FirstMI == nullptr)
81    return true;
82
83  switch (FirstMI->getOpcode()) {
84  case AArch64::ADDWri:
85  case AArch64::ADDWrr:
86  case AArch64::ADDXri:
87  case AArch64::ADDXrr:
88  case AArch64::ANDWri:
89  case AArch64::ANDWrr:
90  case AArch64::ANDXri:
91  case AArch64::ANDXrr:
92  case AArch64::EORWri:
93  case AArch64::EORWrr:
94  case AArch64::EORXri:
95  case AArch64::EORXrr:
96  case AArch64::ORRWri:
97  case AArch64::ORRWrr:
98  case AArch64::ORRXri:
99  case AArch64::ORRXrr:
100  case AArch64::SUBWri:
101  case AArch64::SUBWrr:
102  case AArch64::SUBXri:
103  case AArch64::SUBXrr:
104    return true;
105  case AArch64::ADDWrs:
106  case AArch64::ADDXrs:
107  case AArch64::ANDWrs:
108  case AArch64::ANDXrs:
109  case AArch64::SUBWrs:
110  case AArch64::SUBXrs:
111  case AArch64::BICWrs:
112  case AArch64::BICXrs:
113    // Shift value can be 0 making these behave like the "rr" variant...
114    return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
115  }
116
117  return false;
118}
119
120/// AES crypto encoding or decoding.
121static bool isAESPair(const MachineInstr *FirstMI,
122                      const MachineInstr &SecondMI) {
123  // Assume the 1st instr to be a wildcard if it is unspecified.
124  switch (SecondMI.getOpcode()) {
125  // AES encode.
126  case AArch64::AESMCrr:
127  case AArch64::AESMCrrTied:
128    return FirstMI == nullptr || FirstMI->getOpcode() == AArch64::AESErr;
129  // AES decode.
130  case AArch64::AESIMCrr:
131  case AArch64::AESIMCrrTied:
132    return FirstMI == nullptr || FirstMI->getOpcode() == AArch64::AESDrr;
133  }
134
135  return false;
136}
137
138/// AESE/AESD/PMULL + EOR.
139static bool isCryptoEORPair(const MachineInstr *FirstMI,
140                            const MachineInstr &SecondMI) {
141  if (SecondMI.getOpcode() != AArch64::EORv16i8)
142    return false;
143
144  // Assume the 1st instr to be a wildcard if it is unspecified.
145  if (FirstMI == nullptr)
146    return true;
147
148  switch (FirstMI->getOpcode()) {
149  case AArch64::AESErr:
150  case AArch64::AESDrr:
151  case AArch64::PMULLv16i8:
152  case AArch64::PMULLv8i8:
153  case AArch64::PMULLv1i64:
154  case AArch64::PMULLv2i64:
155    return true;
156  }
157
158  return false;
159}
160
161static bool isAdrpAddPair(const MachineInstr *FirstMI,
162                          const MachineInstr &SecondMI) {
163  // Assume the 1st instr to be a wildcard if it is unspecified.
164  if ((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::ADRP) &&
165      SecondMI.getOpcode() == AArch64::ADDXri)
166    return true;
167  return false;
168}
169
170/// Literal generation.
171static bool isLiteralsPair(const MachineInstr *FirstMI,
172                           const MachineInstr &SecondMI) {
173  // Assume the 1st instr to be a wildcard if it is unspecified.
174  // 32 bit immediate.
175  if ((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::MOVZWi) &&
176      (SecondMI.getOpcode() == AArch64::MOVKWi &&
177       SecondMI.getOperand(3).getImm() == 16))
178    return true;
179
180  // Lower half of 64 bit immediate.
181  if((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::MOVZXi) &&
182     (SecondMI.getOpcode() == AArch64::MOVKXi &&
183      SecondMI.getOperand(3).getImm() == 16))
184    return true;
185
186  // Upper half of 64 bit immediate.
187  if ((FirstMI == nullptr ||
188       (FirstMI->getOpcode() == AArch64::MOVKXi &&
189        FirstMI->getOperand(3).getImm() == 32)) &&
190      (SecondMI.getOpcode() == AArch64::MOVKXi &&
191       SecondMI.getOperand(3).getImm() == 48))
192    return true;
193
194  return false;
195}
196
197/// Fuse address generation and loads or stores.
198static bool isAddressLdStPair(const MachineInstr *FirstMI,
199                              const MachineInstr &SecondMI) {
200  switch (SecondMI.getOpcode()) {
201  case AArch64::STRBBui:
202  case AArch64::STRBui:
203  case AArch64::STRDui:
204  case AArch64::STRHHui:
205  case AArch64::STRHui:
206  case AArch64::STRQui:
207  case AArch64::STRSui:
208  case AArch64::STRWui:
209  case AArch64::STRXui:
210  case AArch64::LDRBBui:
211  case AArch64::LDRBui:
212  case AArch64::LDRDui:
213  case AArch64::LDRHHui:
214  case AArch64::LDRHui:
215  case AArch64::LDRQui:
216  case AArch64::LDRSui:
217  case AArch64::LDRWui:
218  case AArch64::LDRXui:
219  case AArch64::LDRSBWui:
220  case AArch64::LDRSBXui:
221  case AArch64::LDRSHWui:
222  case AArch64::LDRSHXui:
223  case AArch64::LDRSWui:
224    // Assume the 1st instr to be a wildcard if it is unspecified.
225    if (FirstMI == nullptr)
226      return true;
227
228   switch (FirstMI->getOpcode()) {
229    case AArch64::ADR:
230      return SecondMI.getOperand(2).getImm() == 0;
231    case AArch64::ADRP:
232      return true;
233    }
234  }
235
236  return false;
237}
238
239/// Compare and conditional select.
240static bool isCCSelectPair(const MachineInstr *FirstMI,
241                           const MachineInstr &SecondMI) {
242  // 32 bits
243  if (SecondMI.getOpcode() == AArch64::CSELWr) {
244    // Assume the 1st instr to be a wildcard if it is unspecified.
245    if (FirstMI == nullptr)
246      return true;
247
248    if (FirstMI->definesRegister(AArch64::WZR))
249      switch (FirstMI->getOpcode()) {
250      case AArch64::SUBSWrs:
251        return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
252      case AArch64::SUBSWrx:
253        return !AArch64InstrInfo::hasExtendedReg(*FirstMI);
254      case AArch64::SUBSWrr:
255      case AArch64::SUBSWri:
256        return true;
257      }
258  }
259
260  // 64 bits
261  if (SecondMI.getOpcode() == AArch64::CSELXr) {
262    // Assume the 1st instr to be a wildcard if it is unspecified.
263    if (FirstMI == nullptr)
264      return true;
265
266    if (FirstMI->definesRegister(AArch64::XZR))
267      switch (FirstMI->getOpcode()) {
268      case AArch64::SUBSXrs:
269        return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
270      case AArch64::SUBSXrx:
271      case AArch64::SUBSXrx64:
272        return !AArch64InstrInfo::hasExtendedReg(*FirstMI);
273      case AArch64::SUBSXrr:
274      case AArch64::SUBSXri:
275        return true;
276      }
277  }
278
279  return false;
280}
281
282// Arithmetic and logic.
283static bool isArithmeticLogicPair(const MachineInstr *FirstMI,
284                                  const MachineInstr &SecondMI) {
285  if (AArch64InstrInfo::hasShiftedReg(SecondMI))
286    return false;
287
288  switch (SecondMI.getOpcode()) {
289  // Arithmetic
290  case AArch64::ADDWrr:
291  case AArch64::ADDXrr:
292  case AArch64::SUBWrr:
293  case AArch64::SUBXrr:
294  case AArch64::ADDWrs:
295  case AArch64::ADDXrs:
296  case AArch64::SUBWrs:
297  case AArch64::SUBXrs:
298  // Logic
299  case AArch64::ANDWrr:
300  case AArch64::ANDXrr:
301  case AArch64::BICWrr:
302  case AArch64::BICXrr:
303  case AArch64::EONWrr:
304  case AArch64::EONXrr:
305  case AArch64::EORWrr:
306  case AArch64::EORXrr:
307  case AArch64::ORNWrr:
308  case AArch64::ORNXrr:
309  case AArch64::ORRWrr:
310  case AArch64::ORRXrr:
311  case AArch64::ANDWrs:
312  case AArch64::ANDXrs:
313  case AArch64::BICWrs:
314  case AArch64::BICXrs:
315  case AArch64::EONWrs:
316  case AArch64::EONXrs:
317  case AArch64::EORWrs:
318  case AArch64::EORXrs:
319  case AArch64::ORNWrs:
320  case AArch64::ORNXrs:
321  case AArch64::ORRWrs:
322  case AArch64::ORRXrs:
323    // Assume the 1st instr to be a wildcard if it is unspecified.
324    if (FirstMI == nullptr)
325      return true;
326
327    // Arithmetic
328    switch (FirstMI->getOpcode()) {
329    case AArch64::ADDWrr:
330    case AArch64::ADDXrr:
331    case AArch64::ADDSWrr:
332    case AArch64::ADDSXrr:
333    case AArch64::SUBWrr:
334    case AArch64::SUBXrr:
335    case AArch64::SUBSWrr:
336    case AArch64::SUBSXrr:
337      return true;
338    case AArch64::ADDWrs:
339    case AArch64::ADDXrs:
340    case AArch64::ADDSWrs:
341    case AArch64::ADDSXrs:
342    case AArch64::SUBWrs:
343    case AArch64::SUBXrs:
344    case AArch64::SUBSWrs:
345    case AArch64::SUBSXrs:
346      return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
347    }
348    break;
349
350  // Arithmetic, setting flags.
351  case AArch64::ADDSWrr:
352  case AArch64::ADDSXrr:
353  case AArch64::SUBSWrr:
354  case AArch64::SUBSXrr:
355  case AArch64::ADDSWrs:
356  case AArch64::ADDSXrs:
357  case AArch64::SUBSWrs:
358  case AArch64::SUBSXrs:
359    // Assume the 1st instr to be a wildcard if it is unspecified.
360    if (FirstMI == nullptr)
361      return true;
362
363    // Arithmetic, not setting flags.
364    switch (FirstMI->getOpcode()) {
365    case AArch64::ADDWrr:
366    case AArch64::ADDXrr:
367    case AArch64::SUBWrr:
368    case AArch64::SUBXrr:
369      return true;
370    case AArch64::ADDWrs:
371    case AArch64::ADDXrs:
372    case AArch64::SUBWrs:
373    case AArch64::SUBXrs:
374      return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
375    }
376    break;
377  }
378
379  return false;
380}
381
382// "(A + B) + 1" or "(A - B) - 1"
383static bool isAddSub2RegAndConstOnePair(const MachineInstr *FirstMI,
384                                        const MachineInstr &SecondMI) {
385  bool NeedsSubtract = false;
386
387  // The 2nd instr must be an add-immediate or subtract-immediate.
388  switch (SecondMI.getOpcode()) {
389  case AArch64::SUBWri:
390  case AArch64::SUBXri:
391    NeedsSubtract = true;
392    [[fallthrough]];
393  case AArch64::ADDWri:
394  case AArch64::ADDXri:
395    break;
396
397  default:
398    return false;
399  }
400
401  // The immediate in the 2nd instr must be "1".
402  if (!SecondMI.getOperand(2).isImm() || SecondMI.getOperand(2).getImm() != 1) {
403    return false;
404  }
405
406  // Assume the 1st instr to be a wildcard if it is unspecified.
407  if (FirstMI == nullptr) {
408    return true;
409  }
410
411  switch (FirstMI->getOpcode()) {
412  case AArch64::SUBWrs:
413  case AArch64::SUBXrs:
414    if (AArch64InstrInfo::hasShiftedReg(*FirstMI))
415      return false;
416    [[fallthrough]];
417  case AArch64::SUBWrr:
418  case AArch64::SUBXrr:
419    if (NeedsSubtract) {
420      return true;
421    }
422    break;
423
424  case AArch64::ADDWrs:
425  case AArch64::ADDXrs:
426    if (AArch64InstrInfo::hasShiftedReg(*FirstMI))
427      return false;
428    [[fallthrough]];
429  case AArch64::ADDWrr:
430  case AArch64::ADDXrr:
431    if (!NeedsSubtract) {
432      return true;
433    }
434    break;
435  }
436
437  return false;
438}
439
440/// \brief Check if the instr pair, FirstMI and SecondMI, should be fused
441/// together. Given SecondMI, when FirstMI is unspecified, then check if
442/// SecondMI may be part of a fused pair at all.
443static bool shouldScheduleAdjacent(const TargetInstrInfo &TII,
444                                   const TargetSubtargetInfo &TSI,
445                                   const MachineInstr *FirstMI,
446                                   const MachineInstr &SecondMI) {
447  const AArch64Subtarget &ST = static_cast<const AArch64Subtarget&>(TSI);
448
449  // All checking functions assume that the 1st instr is a wildcard if it is
450  // unspecified.
451  if (ST.hasCmpBccFusion() || ST.hasArithmeticBccFusion()) {
452    bool CmpOnly = !ST.hasArithmeticBccFusion();
453    if (isArithmeticBccPair(FirstMI, SecondMI, CmpOnly))
454      return true;
455  }
456  if (ST.hasArithmeticCbzFusion() && isArithmeticCbzPair(FirstMI, SecondMI))
457    return true;
458  if (ST.hasFuseAES() && isAESPair(FirstMI, SecondMI))
459    return true;
460  if (ST.hasFuseCryptoEOR() && isCryptoEORPair(FirstMI, SecondMI))
461    return true;
462  if (ST.hasFuseAdrpAdd() && isAdrpAddPair(FirstMI, SecondMI))
463    return true;
464  if (ST.hasFuseLiterals() && isLiteralsPair(FirstMI, SecondMI))
465    return true;
466  if (ST.hasFuseAddress() && isAddressLdStPair(FirstMI, SecondMI))
467    return true;
468  if (ST.hasFuseCCSelect() && isCCSelectPair(FirstMI, SecondMI))
469    return true;
470  if (ST.hasFuseArithmeticLogic() && isArithmeticLogicPair(FirstMI, SecondMI))
471    return true;
472  if (ST.hasFuseAddSub2RegAndConstOne() &&
473      isAddSub2RegAndConstOnePair(FirstMI, SecondMI))
474    return true;
475
476  return false;
477}
478
479std::unique_ptr<ScheduleDAGMutation>
480llvm::createAArch64MacroFusionDAGMutation() {
481  return createMacroFusionDAGMutation(shouldScheduleAdjacent);
482}
483