AArch64PBQPRegAlloc.cpp revision 360784
1//===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This file contains the AArch64 / Cortex-A57 specific register allocation
9// constraints for use by the PBQP register allocator.
10//
11// It is essentially a transcription of what is contained in
12// AArch64A57FPLoadBalancing, which tries to use a balanced
13// mix of odd and even D-registers when performing a critical sequence of
14// independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
15//===----------------------------------------------------------------------===//
16
17#define DEBUG_TYPE "aarch64-pbqp"
18
19#include "AArch64PBQPRegAlloc.h"
20#include "AArch64.h"
21#include "AArch64RegisterInfo.h"
22#include "llvm/CodeGen/LiveIntervals.h"
23#include "llvm/CodeGen/MachineBasicBlock.h"
24#include "llvm/CodeGen/MachineFunction.h"
25#include "llvm/CodeGen/MachineRegisterInfo.h"
26#include "llvm/CodeGen/RegAllocPBQP.h"
27#include "llvm/Support/Debug.h"
28#include "llvm/Support/ErrorHandling.h"
29#include "llvm/Support/raw_ostream.h"
30
31using namespace llvm;
32
33namespace {
34
35#ifndef NDEBUG
36bool isFPReg(unsigned reg) {
37  return AArch64::FPR32RegClass.contains(reg) ||
38         AArch64::FPR64RegClass.contains(reg) ||
39         AArch64::FPR128RegClass.contains(reg);
40}
41#endif
42
43bool isOdd(unsigned reg) {
44  switch (reg) {
45  default:
46    llvm_unreachable("Register is not from the expected class !");
47  case AArch64::S1:
48  case AArch64::S3:
49  case AArch64::S5:
50  case AArch64::S7:
51  case AArch64::S9:
52  case AArch64::S11:
53  case AArch64::S13:
54  case AArch64::S15:
55  case AArch64::S17:
56  case AArch64::S19:
57  case AArch64::S21:
58  case AArch64::S23:
59  case AArch64::S25:
60  case AArch64::S27:
61  case AArch64::S29:
62  case AArch64::S31:
63  case AArch64::D1:
64  case AArch64::D3:
65  case AArch64::D5:
66  case AArch64::D7:
67  case AArch64::D9:
68  case AArch64::D11:
69  case AArch64::D13:
70  case AArch64::D15:
71  case AArch64::D17:
72  case AArch64::D19:
73  case AArch64::D21:
74  case AArch64::D23:
75  case AArch64::D25:
76  case AArch64::D27:
77  case AArch64::D29:
78  case AArch64::D31:
79  case AArch64::Q1:
80  case AArch64::Q3:
81  case AArch64::Q5:
82  case AArch64::Q7:
83  case AArch64::Q9:
84  case AArch64::Q11:
85  case AArch64::Q13:
86  case AArch64::Q15:
87  case AArch64::Q17:
88  case AArch64::Q19:
89  case AArch64::Q21:
90  case AArch64::Q23:
91  case AArch64::Q25:
92  case AArch64::Q27:
93  case AArch64::Q29:
94  case AArch64::Q31:
95    return true;
96  case AArch64::S0:
97  case AArch64::S2:
98  case AArch64::S4:
99  case AArch64::S6:
100  case AArch64::S8:
101  case AArch64::S10:
102  case AArch64::S12:
103  case AArch64::S14:
104  case AArch64::S16:
105  case AArch64::S18:
106  case AArch64::S20:
107  case AArch64::S22:
108  case AArch64::S24:
109  case AArch64::S26:
110  case AArch64::S28:
111  case AArch64::S30:
112  case AArch64::D0:
113  case AArch64::D2:
114  case AArch64::D4:
115  case AArch64::D6:
116  case AArch64::D8:
117  case AArch64::D10:
118  case AArch64::D12:
119  case AArch64::D14:
120  case AArch64::D16:
121  case AArch64::D18:
122  case AArch64::D20:
123  case AArch64::D22:
124  case AArch64::D24:
125  case AArch64::D26:
126  case AArch64::D28:
127  case AArch64::D30:
128  case AArch64::Q0:
129  case AArch64::Q2:
130  case AArch64::Q4:
131  case AArch64::Q6:
132  case AArch64::Q8:
133  case AArch64::Q10:
134  case AArch64::Q12:
135  case AArch64::Q14:
136  case AArch64::Q16:
137  case AArch64::Q18:
138  case AArch64::Q20:
139  case AArch64::Q22:
140  case AArch64::Q24:
141  case AArch64::Q26:
142  case AArch64::Q28:
143  case AArch64::Q30:
144    return false;
145
146  }
147}
148
149bool haveSameParity(unsigned reg1, unsigned reg2) {
150  assert(isFPReg(reg1) && "Expecting an FP register for reg1");
151  assert(isFPReg(reg2) && "Expecting an FP register for reg2");
152
153  return isOdd(reg1) == isOdd(reg2);
154}
155
156}
157
158bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
159                                                 unsigned Ra) {
160  if (Rd == Ra)
161    return false;
162
163  LiveIntervals &LIs = G.getMetadata().LIS;
164
165  if (Register::isPhysicalRegister(Rd) || Register::isPhysicalRegister(Ra)) {
166    LLVM_DEBUG(dbgs() << "Rd is a physical reg:"
167                      << Register::isPhysicalRegister(Rd) << '\n');
168    LLVM_DEBUG(dbgs() << "Ra is a physical reg:"
169                      << Register::isPhysicalRegister(Ra) << '\n');
170    return false;
171  }
172
173  PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
174  PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
175
176  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
177    &G.getNodeMetadata(node1).getAllowedRegs();
178  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
179    &G.getNodeMetadata(node2).getAllowedRegs();
180
181  PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
182
183  // The edge does not exist. Create one with the appropriate interference
184  // costs.
185  if (edge == G.invalidEdgeId()) {
186    const LiveInterval &ld = LIs.getInterval(Rd);
187    const LiveInterval &la = LIs.getInterval(Ra);
188    bool livesOverlap = ld.overlaps(la);
189
190    PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
191                                 vRaAllowed->size() + 1, 0);
192    for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
193      unsigned pRd = (*vRdAllowed)[i];
194      for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
195        unsigned pRa = (*vRaAllowed)[j];
196        if (livesOverlap && TRI->regsOverlap(pRd, pRa))
197          costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
198        else
199          costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
200      }
201    }
202    G.addEdge(node1, node2, std::move(costs));
203    return true;
204  }
205
206  if (G.getEdgeNode1Id(edge) == node2) {
207    std::swap(node1, node2);
208    std::swap(vRdAllowed, vRaAllowed);
209  }
210
211  // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
212  PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
213  for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
214    unsigned pRd = (*vRdAllowed)[i];
215
216    // Get the maximum cost (excluding unallocatable reg) for same parity
217    // registers
218    PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
219    for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
220      unsigned pRa = (*vRaAllowed)[j];
221      if (haveSameParity(pRd, pRa))
222        if (costs[i + 1][j + 1] !=
223                std::numeric_limits<PBQP::PBQPNum>::infinity() &&
224            costs[i + 1][j + 1] > sameParityMax)
225          sameParityMax = costs[i + 1][j + 1];
226    }
227
228    // Ensure all registers with a different parity have a higher cost
229    // than sameParityMax
230    for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
231      unsigned pRa = (*vRaAllowed)[j];
232      if (!haveSameParity(pRd, pRa))
233        if (sameParityMax > costs[i + 1][j + 1])
234          costs[i + 1][j + 1] = sameParityMax + 1.0;
235    }
236  }
237  G.updateEdgeCosts(edge, std::move(costs));
238
239  return true;
240}
241
242void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
243                                                 unsigned Ra) {
244  LiveIntervals &LIs = G.getMetadata().LIS;
245
246  // Do some Chain management
247  if (Chains.count(Ra)) {
248    if (Rd != Ra) {
249      LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)
250                        << " to " << printReg(Rd, TRI) << '\n';);
251      Chains.remove(Ra);
252      Chains.insert(Rd);
253    }
254  } else {
255    LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)
256                      << '\n';);
257    Chains.insert(Rd);
258  }
259
260  PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
261
262  const LiveInterval &ld = LIs.getInterval(Rd);
263  for (auto r : Chains) {
264    // Skip self
265    if (r == Rd)
266      continue;
267
268    const LiveInterval &lr = LIs.getInterval(r);
269    if (ld.overlaps(lr)) {
270      const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
271        &G.getNodeMetadata(node1).getAllowedRegs();
272
273      PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
274      const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
275        &G.getNodeMetadata(node2).getAllowedRegs();
276
277      PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
278      assert(edge != G.invalidEdgeId() &&
279             "PBQP error ! The edge should exist !");
280
281      LLVM_DEBUG(dbgs() << "Refining constraint !\n";);
282
283      if (G.getEdgeNode1Id(edge) == node2) {
284        std::swap(node1, node2);
285        std::swap(vRdAllowed, vRrAllowed);
286      }
287
288      // Enforce that cost is higher with all other Chains of the same parity
289      PBQP::Matrix costs(G.getEdgeCosts(edge));
290      for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
291        unsigned pRd = (*vRdAllowed)[i];
292
293        // Get the maximum cost (excluding unallocatable reg) for all other
294        // parity registers
295        PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
296        for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
297          unsigned pRa = (*vRrAllowed)[j];
298          if (!haveSameParity(pRd, pRa))
299            if (costs[i + 1][j + 1] !=
300                    std::numeric_limits<PBQP::PBQPNum>::infinity() &&
301                costs[i + 1][j + 1] > sameParityMax)
302              sameParityMax = costs[i + 1][j + 1];
303        }
304
305        // Ensure all registers with same parity have a higher cost
306        // than sameParityMax
307        for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
308          unsigned pRa = (*vRrAllowed)[j];
309          if (haveSameParity(pRd, pRa))
310            if (sameParityMax > costs[i + 1][j + 1])
311              costs[i + 1][j + 1] = sameParityMax + 1.0;
312        }
313      }
314      G.updateEdgeCosts(edge, std::move(costs));
315    }
316  }
317}
318
319static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
320                                const MachineInstr &MI) {
321  const LiveInterval &LI = LIs.getInterval(reg);
322  SlotIndex SI = LIs.getInstructionIndex(MI);
323  return LI.expiredAt(SI);
324}
325
326void A57ChainingConstraint::apply(PBQPRAGraph &G) {
327  const MachineFunction &MF = G.getMetadata().MF;
328  LiveIntervals &LIs = G.getMetadata().LIS;
329
330  TRI = MF.getSubtarget().getRegisterInfo();
331  LLVM_DEBUG(MF.dump());
332
333  for (const auto &MBB: MF) {
334    Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
335
336    for (const auto &MI: MBB) {
337
338      // Forget Chains which have expired
339      for (auto r : Chains) {
340        SmallVector<unsigned, 8> toDel;
341        if(regJustKilledBefore(LIs, r, MI)) {
342          LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";
343                     MI.print(dbgs()););
344          toDel.push_back(r);
345        }
346
347        while (!toDel.empty()) {
348          Chains.remove(toDel.back());
349          toDel.pop_back();
350        }
351      }
352
353      switch (MI.getOpcode()) {
354      case AArch64::FMSUBSrrr:
355      case AArch64::FMADDSrrr:
356      case AArch64::FNMSUBSrrr:
357      case AArch64::FNMADDSrrr:
358      case AArch64::FMSUBDrrr:
359      case AArch64::FMADDDrrr:
360      case AArch64::FNMSUBDrrr:
361      case AArch64::FNMADDDrrr: {
362        Register Rd = MI.getOperand(0).getReg();
363        Register Ra = MI.getOperand(3).getReg();
364
365        if (addIntraChainConstraint(G, Rd, Ra))
366          addInterChainConstraint(G, Rd, Ra);
367        break;
368      }
369
370      case AArch64::FMLAv2f32:
371      case AArch64::FMLSv2f32: {
372        Register Rd = MI.getOperand(0).getReg();
373        addInterChainConstraint(G, Rd, Rd);
374        break;
375      }
376
377      default:
378        break;
379      }
380    }
381  }
382}
383