1//===- GlobalCombinerEmitter.cpp - Generate a combiner --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file Generate a combiner implementation for GlobalISel from a declarative
10/// syntax
11///
12//===----------------------------------------------------------------------===//
13
14#include "llvm/ADT/SmallSet.h"
15#include "llvm/ADT/Statistic.h"
16#include "llvm/ADT/StringSet.h"
17#include "llvm/Support/CommandLine.h"
18#include "llvm/Support/ScopedPrinter.h"
19#include "llvm/Support/Timer.h"
20#include "llvm/TableGen/Error.h"
21#include "llvm/TableGen/StringMatcher.h"
22#include "llvm/TableGen/TableGenBackend.h"
23#include "CodeGenTarget.h"
24#include "GlobalISel/CodeExpander.h"
25#include "GlobalISel/CodeExpansions.h"
26#include "GlobalISel/GIMatchDag.h"
27#include "GlobalISel/GIMatchTree.h"
28#include <cstdint>
29
30using namespace llvm;
31
32#define DEBUG_TYPE "gicombiner-emitter"
33
// FIXME: Use ALWAYS_ENABLED_STATISTIC once it's available.
// Running count of the rules that have been gathered so far. gatherRules()
// increments it after each rule that parses successfully, and
// makeCombineRule() passes the current value as the ID of the rule it is
// about to build, so IDs are assigned densely in gathering order.
unsigned NumPatternTotal = 0;
STATISTIC(NumPatternTotalStatistic, "Total number of patterns");
37
38cl::OptionCategory
39    GICombinerEmitterCat("Options for -gen-global-isel-combiner");
40static cl::list<std::string>
41    SelectedCombiners("combiners", cl::desc("Emit the specified combiners"),
42                      cl::cat(GICombinerEmitterCat), cl::CommaSeparated);
43static cl::opt<bool> ShowExpansions(
44    "gicombiner-show-expansions",
45    cl::desc("Use C++ comments to indicate occurence of code expansion"),
46    cl::cat(GICombinerEmitterCat));
47static cl::opt<bool> StopAfterParse(
48    "gicombiner-stop-after-parse",
49    cl::desc("Stop processing after parsing rules and dump state"),
50    cl::cat(GICombinerEmitterCat));
51static cl::opt<bool> StopAfterBuild(
52    "gicombiner-stop-after-build",
53    cl::desc("Stop processing after building the match tree"),
54    cl::cat(GICombinerEmitterCat));
55
56namespace {
/// Numeric identifier assigned to each combine rule.
using RuleID = uint64_t;
58
59// We're going to be referencing the same small strings quite a lot for operand
60// names and the like. Make their lifetime management simple with a global
61// string table.
62StringSet<> StrTab;
63
64StringRef insertStrTab(StringRef S) {
65  if (S.empty())
66    return S;
67  return StrTab.insert(S).first->first();
68}
69
70class format_partition_name {
71  const GIMatchTree &Tree;
72  unsigned Idx;
73
74public:
75  format_partition_name(const GIMatchTree &Tree, unsigned Idx)
76      : Tree(Tree), Idx(Idx) {}
77  void print(raw_ostream &OS) const {
78    Tree.getPartitioner()->emitPartitionName(OS, Idx);
79  }
80};
81raw_ostream &operator<<(raw_ostream &OS, const format_partition_name &Fmt) {
82  Fmt.print(OS);
83  return OS;
84}
85
86/// Declares data that is passed from the match stage to the apply stage.
87class MatchDataInfo {
88  /// The symbol used in the tablegen patterns
89  StringRef PatternSymbol;
90  /// The data type for the variable
91  StringRef Type;
92  /// The name of the variable as declared in the generated matcher.
93  std::string VariableName;
94
95public:
96  MatchDataInfo(StringRef PatternSymbol, StringRef Type, StringRef VariableName)
97      : PatternSymbol(PatternSymbol), Type(Type), VariableName(VariableName) {}
98
99  StringRef getPatternSymbol() const { return PatternSymbol; };
100  StringRef getType() const { return Type; };
101  StringRef getVariableName() const { return VariableName; };
102};
103
104class RootInfo {
105  StringRef PatternSymbol;
106
107public:
108  RootInfo(StringRef PatternSymbol) : PatternSymbol(PatternSymbol) {}
109
110  StringRef getPatternSymbol() const { return PatternSymbol; }
111};
112
113class CombineRule {
114public:
115
116  using const_matchdata_iterator = std::vector<MatchDataInfo>::const_iterator;
117
118  struct VarInfo {
119    const GIMatchDagInstr *N;
120    const GIMatchDagOperand *Op;
121    const DagInit *Matcher;
122
123  public:
124    VarInfo(const GIMatchDagInstr *N, const GIMatchDagOperand *Op,
125            const DagInit *Matcher)
126        : N(N), Op(Op), Matcher(Matcher) {}
127  };
128
129protected:
130  /// A unique ID for this rule
131  /// ID's are used for debugging and run-time disabling of rules among other
132  /// things.
133  RuleID ID;
134
135  /// A unique ID that can be used for anonymous objects belonging to this rule.
136  /// Used to create unique names in makeNameForAnon*() without making tests
137  /// overly fragile.
138  unsigned UID = 0;
139
140  /// The record defining this rule.
141  const Record &TheDef;
142
143  /// The roots of a match. These are the leaves of the DAG that are closest to
144  /// the end of the function. I.e. the nodes that are encountered without
145  /// following any edges of the DAG described by the pattern as we work our way
146  /// from the bottom of the function to the top.
147  std::vector<RootInfo> Roots;
148
149  GIMatchDag MatchDag;
150
151  /// A block of arbitrary C++ to finish testing the match.
152  /// FIXME: This is a temporary measure until we have actual pattern matching
153  const CodeInit *MatchingFixupCode = nullptr;
154
155  /// The MatchData defined by the match stage and required by the apply stage.
156  /// This allows the plumbing of arbitrary data from C++ predicates between the
157  /// stages.
158  ///
159  /// For example, suppose you have:
160  ///   %A = <some-constant-expr>
161  ///   %0 = G_ADD %1, %A
162  /// you could define a GIMatchPredicate that walks %A, constant folds as much
163  /// as possible and returns an APInt containing the discovered constant. You
164  /// could then declare:
165  ///   def apint : GIDefMatchData<"APInt">;
166  /// add it to the rule with:
167  ///   (defs root:$root, apint:$constant)
168  /// evaluate it in the pattern with a C++ function that takes a
169  /// MachineOperand& and an APInt& with:
170  ///   (match [{MIR %root = G_ADD %0, %A }],
171  ///             (constantfold operand:$A, apint:$constant))
172  /// and finally use it in the apply stage with:
173  ///   (apply (create_operand
174  ///                [{ MachineOperand::CreateImm(${constant}.getZExtValue());
175  ///                ]}, apint:$constant),
176  ///             [{MIR %root = FOO %0, %constant }])
177  std::vector<MatchDataInfo> MatchDataDecls;
178
179  void declareMatchData(StringRef PatternSymbol, StringRef Type,
180                        StringRef VarName);
181
182  bool parseInstructionMatcher(const CodeGenTarget &Target, StringInit *ArgName,
183                               const Init &Arg,
184                               StringMap<std::vector<VarInfo>> &NamedEdgeDefs,
185                               StringMap<std::vector<VarInfo>> &NamedEdgeUses);
186  bool parseWipMatchOpcodeMatcher(const CodeGenTarget &Target,
187                                  StringInit *ArgName, const Init &Arg);
188
189public:
190  CombineRule(const CodeGenTarget &Target, GIMatchDagContext &Ctx, RuleID ID,
191              const Record &R)
192      : ID(ID), TheDef(R), MatchDag(Ctx) {}
193  CombineRule(const CombineRule &) = delete;
194
195  bool parseDefs();
196  bool parseMatcher(const CodeGenTarget &Target);
197
198  RuleID getID() const { return ID; }
199  unsigned allocUID() { return UID++; }
200  StringRef getName() const { return TheDef.getName(); }
201  const Record &getDef() const { return TheDef; }
202  const CodeInit *getMatchingFixupCode() const { return MatchingFixupCode; }
203  size_t getNumRoots() const { return Roots.size(); }
204
205  GIMatchDag &getMatchDag() { return MatchDag; }
206  const GIMatchDag &getMatchDag() const { return MatchDag; }
207
208  using const_root_iterator = std::vector<RootInfo>::const_iterator;
209  const_root_iterator roots_begin() const { return Roots.begin(); }
210  const_root_iterator roots_end() const { return Roots.end(); }
211  iterator_range<const_root_iterator> roots() const {
212    return llvm::make_range(Roots.begin(), Roots.end());
213  }
214
215  iterator_range<const_matchdata_iterator> matchdata_decls() const {
216    return make_range(MatchDataDecls.begin(), MatchDataDecls.end());
217  }
218
219  /// Export expansions for this rule
220  void declareExpansions(CodeExpansions &Expansions) const {
221    for (const auto &I : matchdata_decls())
222      Expansions.declare(I.getPatternSymbol(), I.getVariableName());
223  }
224
225  /// The matcher will begin from the roots and will perform the match by
226  /// traversing the edges to cover the whole DAG. This function reverses DAG
227  /// edges such that everything is reachable from a root. This is part of the
228  /// preparation work for flattening the DAG into a tree.
229  void reorientToRoots() {
230    SmallSet<const GIMatchDagInstr *, 5> Roots;
231    SmallSet<const GIMatchDagInstr *, 5> Visited;
232    SmallSet<GIMatchDagEdge *, 20> EdgesRemaining;
233
234    for (auto &I : MatchDag.roots()) {
235      Roots.insert(I);
236      Visited.insert(I);
237    }
238    for (auto &I : MatchDag.edges())
239      EdgesRemaining.insert(I);
240
241    bool Progressed = false;
242    SmallSet<GIMatchDagEdge *, 20> EdgesToRemove;
243    while (!EdgesRemaining.empty()) {
244      for (auto EI = EdgesRemaining.begin(), EE = EdgesRemaining.end();
245           EI != EE; ++EI) {
246        if (Visited.count((*EI)->getFromMI())) {
247          if (Roots.count((*EI)->getToMI()))
248            PrintError(TheDef.getLoc(), "One or more roots are unnecessary");
249          Visited.insert((*EI)->getToMI());
250          EdgesToRemove.insert(*EI);
251          Progressed = true;
252        }
253      }
254      for (GIMatchDagEdge *ToRemove : EdgesToRemove)
255        EdgesRemaining.erase(ToRemove);
256      EdgesToRemove.clear();
257
258      for (auto EI = EdgesRemaining.begin(), EE = EdgesRemaining.end();
259           EI != EE; ++EI) {
260        if (Visited.count((*EI)->getToMI())) {
261          (*EI)->reverse();
262          Visited.insert((*EI)->getToMI());
263          EdgesToRemove.insert(*EI);
264          Progressed = true;
265        }
266        for (GIMatchDagEdge *ToRemove : EdgesToRemove)
267          EdgesRemaining.erase(ToRemove);
268        EdgesToRemove.clear();
269      }
270
271      if (!Progressed) {
272        LLVM_DEBUG(dbgs() << "No progress\n");
273        return;
274      }
275      Progressed = false;
276    }
277  }
278};
279
280/// A convenience function to check that an Init refers to a specific def. This
281/// is primarily useful for testing for defs and similar in DagInit's since
282/// DagInit's support any type inside them.
283static bool isSpecificDef(const Init &N, StringRef Def) {
284  if (const DefInit *OpI = dyn_cast<DefInit>(&N))
285    if (OpI->getDef()->getName() == Def)
286      return true;
287  return false;
288}
289
290/// A convenience function to check that an Init refers to a def that is a
291/// subclass of the given class and coerce it to a def if it is. This is
292/// primarily useful for testing for subclasses of GIMatchKind and similar in
293/// DagInit's since DagInit's support any type inside them.
294static Record *getDefOfSubClass(const Init &N, StringRef Cls) {
295  if (const DefInit *OpI = dyn_cast<DefInit>(&N))
296    if (OpI->getDef()->isSubClassOf(Cls))
297      return OpI->getDef();
298  return nullptr;
299}
300
301/// A convenience function to check that an Init refers to a dag whose operator
302/// is a specific def and coerce it to a dag if it is. This is primarily useful
303/// for testing for subclasses of GIMatchKind and similar in DagInit's since
304/// DagInit's support any type inside them.
305static const DagInit *getDagWithSpecificOperator(const Init &N,
306                                                 StringRef Name) {
307  if (const DagInit *I = dyn_cast<DagInit>(&N))
308    if (I->getNumArgs() > 0)
309      if (const DefInit *OpI = dyn_cast<DefInit>(I->getOperator()))
310        if (OpI->getDef()->getName() == Name)
311          return I;
312  return nullptr;
313}
314
315/// A convenience function to check that an Init refers to a dag whose operator
316/// is a def that is a subclass of the given class and coerce it to a dag if it
317/// is. This is primarily useful for testing for subclasses of GIMatchKind and
318/// similar in DagInit's since DagInit's support any type inside them.
319static const DagInit *getDagWithOperatorOfSubClass(const Init &N,
320                                                   StringRef Cls) {
321  if (const DagInit *I = dyn_cast<DagInit>(&N))
322    if (I->getNumArgs() > 0)
323      if (const DefInit *OpI = dyn_cast<DefInit>(I->getOperator()))
324        if (OpI->getDef()->isSubClassOf(Cls))
325          return I;
326  return nullptr;
327}
328
329StringRef makeNameForAnonInstr(CombineRule &Rule) {
330  return insertStrTab(to_string(
331      format("__anon%" PRIu64 "_%u", Rule.getID(), Rule.allocUID())));
332}
333
334StringRef makeDebugName(CombineRule &Rule, StringRef Name) {
335  return insertStrTab(Name.empty() ? makeNameForAnonInstr(Rule) : StringRef(Name));
336}
337
338StringRef makeNameForAnonPredicate(CombineRule &Rule) {
339  return insertStrTab(to_string(
340      format("__anonpred%" PRIu64 "_%u", Rule.getID(), Rule.allocUID())));
341}
342
343void CombineRule::declareMatchData(StringRef PatternSymbol, StringRef Type,
344                                   StringRef VarName) {
345  MatchDataDecls.emplace_back(PatternSymbol, Type, VarName);
346}
347
348bool CombineRule::parseDefs() {
349  NamedRegionTimer T("parseDefs", "Time spent parsing the defs", "Rule Parsing",
350                     "Time spent on rule parsing", TimeRegions);
351  DagInit *Defs = TheDef.getValueAsDag("Defs");
352
353  if (Defs->getOperatorAsDef(TheDef.getLoc())->getName() != "defs") {
354    PrintError(TheDef.getLoc(), "Expected defs operator");
355    return false;
356  }
357
358  for (unsigned I = 0, E = Defs->getNumArgs(); I < E; ++I) {
359    // Roots should be collected into Roots
360    if (isSpecificDef(*Defs->getArg(I), "root")) {
361      Roots.emplace_back(Defs->getArgNameStr(I));
362      continue;
363    }
364
365    // Subclasses of GIDefMatchData should declare that this rule needs to pass
366    // data from the match stage to the apply stage, and ensure that the
367    // generated matcher has a suitable variable for it to do so.
368    if (Record *MatchDataRec =
369            getDefOfSubClass(*Defs->getArg(I), "GIDefMatchData")) {
370      declareMatchData(Defs->getArgNameStr(I),
371                       MatchDataRec->getValueAsString("Type"),
372                       llvm::to_string(llvm::format("MatchData%" PRIu64, ID)));
373      continue;
374    }
375
376    // Otherwise emit an appropriate error message.
377    if (getDefOfSubClass(*Defs->getArg(I), "GIDefKind"))
378      PrintError(TheDef.getLoc(),
379                 "This GIDefKind not implemented in tablegen");
380    else if (getDefOfSubClass(*Defs->getArg(I), "GIDefKindWithArgs"))
381      PrintError(TheDef.getLoc(),
382                 "This GIDefKindWithArgs not implemented in tablegen");
383    else
384      PrintError(TheDef.getLoc(),
385                 "Expected a subclass of GIDefKind or a sub-dag whose "
386                 "operator is of type GIDefKindWithArgs");
387    return false;
388  }
389
390  if (Roots.empty()) {
391    PrintError(TheDef.getLoc(), "Combine rules must have at least one root");
392    return false;
393  }
394  return true;
395}
396
397// Parse an (Instruction $a:Arg1, $b:Arg2, ...) matcher. Edges are formed
398// between matching operand names between different matchers.
399bool CombineRule::parseInstructionMatcher(
400    const CodeGenTarget &Target, StringInit *ArgName, const Init &Arg,
401    StringMap<std::vector<VarInfo>> &NamedEdgeDefs,
402    StringMap<std::vector<VarInfo>> &NamedEdgeUses) {
403  if (const DagInit *Matcher =
404          getDagWithOperatorOfSubClass(Arg, "Instruction")) {
405    auto &Instr =
406        Target.getInstruction(Matcher->getOperatorAsDef(TheDef.getLoc()));
407
408    StringRef Name = ArgName ? ArgName->getValue() : "";
409
410    GIMatchDagInstr *N =
411        MatchDag.addInstrNode(makeDebugName(*this, Name), insertStrTab(Name),
412                              MatchDag.getContext().makeOperandList(Instr));
413
414    N->setOpcodeAnnotation(&Instr);
415    const auto &P = MatchDag.addPredicateNode<GIMatchDagOpcodePredicate>(
416        makeNameForAnonPredicate(*this), Instr);
417    MatchDag.addPredicateDependency(N, nullptr, P, &P->getOperandInfo()["mi"]);
418    unsigned OpIdx = 0;
419    for (const auto &NameInit : Matcher->getArgNames()) {
420      StringRef Name = insertStrTab(NameInit->getAsUnquotedString());
421      if (Name.empty())
422        continue;
423      N->assignNameToOperand(OpIdx, Name);
424
425      // Record the endpoints of any named edges. We'll add the cartesian
426      // product of edges later.
427      const auto &InstrOperand = N->getOperandInfo()[OpIdx];
428      if (InstrOperand.isDef()) {
429        NamedEdgeDefs.try_emplace(Name);
430        NamedEdgeDefs[Name].emplace_back(N, &InstrOperand, Matcher);
431      } else {
432        NamedEdgeUses.try_emplace(Name);
433        NamedEdgeUses[Name].emplace_back(N, &InstrOperand, Matcher);
434      }
435
436      if (InstrOperand.isDef()) {
437        if (find_if(Roots, [&](const RootInfo &X) {
438              return X.getPatternSymbol() == Name;
439            }) != Roots.end()) {
440          N->setMatchRoot();
441        }
442      }
443
444      OpIdx++;
445    }
446
447    return true;
448  }
449  return false;
450}
451
452// Parse the wip_match_opcode placeholder that's temporarily present in lieu of
453// implementing macros or choices between two matchers.
454bool CombineRule::parseWipMatchOpcodeMatcher(const CodeGenTarget &Target,
455                                             StringInit *ArgName,
456                                             const Init &Arg) {
457  if (const DagInit *Matcher =
458          getDagWithSpecificOperator(Arg, "wip_match_opcode")) {
459    StringRef Name = ArgName ? ArgName->getValue() : "";
460
461    GIMatchDagInstr *N =
462        MatchDag.addInstrNode(makeDebugName(*this, Name), insertStrTab(Name),
463                              MatchDag.getContext().makeEmptyOperandList());
464
465    if (find_if(Roots, [&](const RootInfo &X) {
466          return ArgName && X.getPatternSymbol() == ArgName->getValue();
467        }) != Roots.end()) {
468      N->setMatchRoot();
469    }
470
471    const auto &P = MatchDag.addPredicateNode<GIMatchDagOneOfOpcodesPredicate>(
472        makeNameForAnonPredicate(*this));
473    MatchDag.addPredicateDependency(N, nullptr, P, &P->getOperandInfo()["mi"]);
474    // Each argument is an opcode that will pass this predicate. Add them all to
475    // the predicate implementation
476    for (const auto &Arg : Matcher->getArgs()) {
477      Record *OpcodeDef = getDefOfSubClass(*Arg, "Instruction");
478      if (OpcodeDef) {
479        P->addOpcode(&Target.getInstruction(OpcodeDef));
480        continue;
481      }
482      PrintError(TheDef.getLoc(),
483                 "Arguments to wip_match_opcode must be instructions");
484      return false;
485    }
486    return true;
487  }
488  return false;
489}
490bool CombineRule::parseMatcher(const CodeGenTarget &Target) {
491  NamedRegionTimer T("parseMatcher", "Time spent parsing the matcher",
492                     "Rule Parsing", "Time spent on rule parsing", TimeRegions);
493  StringMap<std::vector<VarInfo>> NamedEdgeDefs;
494  StringMap<std::vector<VarInfo>> NamedEdgeUses;
495  DagInit *Matchers = TheDef.getValueAsDag("Match");
496
497  if (Matchers->getOperatorAsDef(TheDef.getLoc())->getName() != "match") {
498    PrintError(TheDef.getLoc(), "Expected match operator");
499    return false;
500  }
501
502  if (Matchers->getNumArgs() == 0) {
503    PrintError(TheDef.getLoc(), "Matcher is empty");
504    return false;
505  }
506
507  // The match section consists of a list of matchers and predicates. Parse each
508  // one and add the equivalent GIMatchDag nodes, predicates, and edges.
509  for (unsigned I = 0; I < Matchers->getNumArgs(); ++I) {
510    if (parseInstructionMatcher(Target, Matchers->getArgName(I),
511                                *Matchers->getArg(I), NamedEdgeDefs,
512                                NamedEdgeUses))
513      continue;
514
515    if (parseWipMatchOpcodeMatcher(Target, Matchers->getArgName(I),
516                                   *Matchers->getArg(I)))
517      continue;
518
519
520    // Parse arbitrary C++ code we have in lieu of supporting MIR matching
521    if (const CodeInit *CodeI = dyn_cast<CodeInit>(Matchers->getArg(I))) {
522      assert(!MatchingFixupCode &&
523             "Only one block of arbitrary code is currently permitted");
524      MatchingFixupCode = CodeI;
525      MatchDag.setHasPostMatchPredicate(true);
526      continue;
527    }
528
529    PrintError(TheDef.getLoc(),
530               "Expected a subclass of GIMatchKind or a sub-dag whose "
531               "operator is either of a GIMatchKindWithArgs or Instruction");
532    PrintNote("Pattern was `" + Matchers->getArg(I)->getAsString() + "'");
533    return false;
534  }
535
536  // Add the cartesian product of use -> def edges.
537  bool FailedToAddEdges = false;
538  for (const auto &NameAndDefs : NamedEdgeDefs) {
539    if (NameAndDefs.getValue().size() > 1) {
540      PrintError(TheDef.getLoc(),
541                 "Two different MachineInstrs cannot def the same vreg");
542      for (const auto &NameAndDefOp : NameAndDefs.getValue())
543        PrintNote("in " + to_string(*NameAndDefOp.N) + " created from " +
544                  to_string(*NameAndDefOp.Matcher) + "");
545      FailedToAddEdges = true;
546    }
547    const auto &Uses = NamedEdgeUses[NameAndDefs.getKey()];
548    for (const VarInfo &DefVar : NameAndDefs.getValue()) {
549      for (const VarInfo &UseVar : Uses) {
550        MatchDag.addEdge(insertStrTab(NameAndDefs.getKey()), UseVar.N, UseVar.Op,
551                         DefVar.N, DefVar.Op);
552      }
553    }
554  }
555  if (FailedToAddEdges)
556    return false;
557
558  // If a variable is referenced in multiple use contexts then we need a
559  // predicate to confirm they are the same operand. We can elide this if it's
560  // also referenced in a def context and we're traversing the def-use chain
561  // from the def to the uses but we can't know which direction we're going
562  // until after reorientToRoots().
563  for (const auto &NameAndUses : NamedEdgeUses) {
564    const auto &Uses = NameAndUses.getValue();
565    if (Uses.size() > 1) {
566      const auto &LeadingVar = Uses.front();
567      for (const auto &Var : ArrayRef<VarInfo>(Uses).drop_front()) {
568        // Add a predicate for each pair until we've covered the whole
569        // equivalence set. We could test the whole set in a single predicate
570        // but that means we can't test any equivalence until all the MO's are
571        // available which can lead to wasted work matching the DAG when this
572        // predicate can already be seen to have failed.
573        //
574        // We have a similar problem due to the need to wait for a particular MO
575        // before being able to test any of them. However, that is mitigated by
576        // the order in which we build the DAG. We build from the roots outwards
577        // so by using the first recorded use in all the predicates, we are
578        // making the dependency on one of the earliest visited references in
579        // the DAG. It's not guaranteed once the generated matcher is optimized
580        // (because the factoring the common portions of rules might change the
581        // visit order) but this should mean that these predicates depend on the
582        // first MO to become available.
583        const auto &P = MatchDag.addPredicateNode<GIMatchDagSameMOPredicate>(
584            makeNameForAnonPredicate(*this));
585        MatchDag.addPredicateDependency(LeadingVar.N, LeadingVar.Op, P,
586                                        &P->getOperandInfo()["mi0"]);
587        MatchDag.addPredicateDependency(Var.N, Var.Op, P,
588                                        &P->getOperandInfo()["mi1"]);
589      }
590    }
591  }
592  return true;
593}
594
/// Emitter state for a single combiner: the record that defines it, the
/// target it is generated for, and the combine rules gathered for it.
class GICombinerEmitter {
  /// Name passed to the constructor.
  StringRef Name;
  const CodeGenTarget &Target;
  /// The record describing the combiner; getClassName() reads its
  /// "Classname" field.
  Record *Combiner;
  /// The parsed rules; emitNameMatcher() emits one case per entry.
  std::vector<std::unique_ptr<CombineRule>> Rules;
  /// Context handed to every CombineRule for building its match DAG.
  GIMatchDagContext MatchDagCtx;

  /// Parse \p R into a CombineRule. Returns nullptr if the rule is invalid or
  /// uses a feature that is not supported yet.
  std::unique_ptr<CombineRule> makeCombineRule(const Record &R);

  /// Append the rules found in \p RulesAndGroups to \p ActiveRules, recursing
  /// into records whose "Rules" field is set.
  void gatherRules(std::vector<std::unique_ptr<CombineRule>> &ActiveRules,
                   const std::vector<Record *> &&RulesAndGroups);

public:
  explicit GICombinerEmitter(RecordKeeper &RK, const CodeGenTarget &Target,
                             StringRef Name, Record *Combiner);
  ~GICombinerEmitter() {}

  /// The value of the combiner record's "Classname" field.
  StringRef getClassName() const {
    return Combiner->getValueAsString("Classname");
  }
  void run(raw_ostream &OS);

  /// Emit the name matcher (guarded by #ifndef NDEBUG) used to disable rules in
  /// response to the generated cl::opt.
  void emitNameMatcher(raw_ostream &OS) const;

  void generateDeclarationsCodeForTree(raw_ostream &OS, const GIMatchTree &Tree) const;
  /// Emit the matching code for \p Tree to \p OS, prefixing each emitted line
  /// with \p Indent.
  void generateCodeForTree(raw_ostream &OS, const GIMatchTree &Tree,
                           StringRef Indent) const;
};
625
/// Store the combiner's name, target and defining record. \p RK is accepted
/// but not used by this constructor.
GICombinerEmitter::GICombinerEmitter(RecordKeeper &RK,
                                     const CodeGenTarget &Target,
                                     StringRef Name, Record *Combiner)
    : Name(Name), Target(Target), Combiner(Combiner) {}
630
631void GICombinerEmitter::emitNameMatcher(raw_ostream &OS) const {
632  std::vector<std::pair<std::string, std::string>> Cases;
633  Cases.reserve(Rules.size());
634
635  for (const CombineRule &EnumeratedRule : make_pointee_range(Rules)) {
636    std::string Code;
637    raw_string_ostream SS(Code);
638    SS << "return " << EnumeratedRule.getID() << ";\n";
639    Cases.push_back(std::make_pair(EnumeratedRule.getName(), SS.str()));
640  }
641
642  OS << "static Optional<uint64_t> getRuleIdxForIdentifier(StringRef "
643        "RuleIdentifier) {\n"
644     << "  uint64_t I;\n"
645     << "  // getAtInteger(...) returns false on success\n"
646     << "  bool Parsed = !RuleIdentifier.getAsInteger(0, I);\n"
647     << "  if (Parsed)\n"
648     << "    return I;\n\n"
649     << "#ifndef NDEBUG\n";
650  StringMatcher Matcher("RuleIdentifier", Cases, OS);
651  Matcher.Emit();
652  OS << "#endif // ifndef NDEBUG\n\n"
653     << "  return None;\n"
654     << "}\n";
655}
656
657std::unique_ptr<CombineRule>
658GICombinerEmitter::makeCombineRule(const Record &TheDef) {
659  std::unique_ptr<CombineRule> Rule =
660      std::make_unique<CombineRule>(Target, MatchDagCtx, NumPatternTotal, TheDef);
661
662  if (!Rule->parseDefs())
663    return nullptr;
664  if (!Rule->parseMatcher(Target))
665    return nullptr;
666
667  Rule->reorientToRoots();
668
669  LLVM_DEBUG({
670    dbgs() << "Parsed rule defs/match for '" << Rule->getName() << "'\n";
671    Rule->getMatchDag().dump();
672    Rule->getMatchDag().writeDOTGraph(dbgs(), Rule->getName());
673  });
674  if (StopAfterParse)
675    return Rule;
676
677  // For now, don't support traversing from def to use. We'll come back to
678  // this later once we have the algorithm changes to support it.
679  bool EmittedDefToUseError = false;
680  for (const auto &E : Rule->getMatchDag().edges()) {
681    if (E->isDefToUse()) {
682      if (!EmittedDefToUseError) {
683        PrintError(
684            TheDef.getLoc(),
685            "Generated state machine cannot lookup uses from a def (yet)");
686        EmittedDefToUseError = true;
687      }
688      PrintNote("Node " + to_string(*E->getFromMI()));
689      PrintNote("Node " + to_string(*E->getToMI()));
690      PrintNote("Edge " + to_string(*E));
691    }
692  }
693  if (EmittedDefToUseError)
694    return nullptr;
695
696  // For now, don't support multi-root rules. We'll come back to this later
697  // once we have the algorithm changes to support it.
698  if (Rule->getNumRoots() > 1) {
699    PrintError(TheDef.getLoc(), "Multi-root matches are not supported (yet)");
700    return nullptr;
701  }
702  return Rule;
703}
704
705/// Recurse into GICombineGroup's and flatten the ruleset into a simple list.
706void GICombinerEmitter::gatherRules(
707    std::vector<std::unique_ptr<CombineRule>> &ActiveRules,
708    const std::vector<Record *> &&RulesAndGroups) {
709  for (Record *R : RulesAndGroups) {
710    if (R->isValueUnset("Rules")) {
711      std::unique_ptr<CombineRule> Rule = makeCombineRule(*R);
712      if (Rule == nullptr) {
713        PrintError(R->getLoc(), "Failed to parse rule");
714        continue;
715      }
716      ActiveRules.emplace_back(std::move(Rule));
717      ++NumPatternTotal;
718    } else
719      gatherRules(ActiveRules, R->getValueAsListOfDefs("Rules"));
720  }
721}
722
723void GICombinerEmitter::generateCodeForTree(raw_ostream &OS,
724                                            const GIMatchTree &Tree,
725                                            StringRef Indent) const {
726  if (Tree.getPartitioner() != nullptr) {
727    Tree.getPartitioner()->generatePartitionSelectorCode(OS, Indent);
728    for (const auto &EnumChildren : enumerate(Tree.children())) {
729      OS << Indent << "if (Partition == " << EnumChildren.index() << " /* "
730         << format_partition_name(Tree, EnumChildren.index()) << " */) {\n";
731      generateCodeForTree(OS, EnumChildren.value(), (Indent + "  ").str());
732      OS << Indent << "}\n";
733    }
734    return;
735  }
736
737  bool AnyFullyTested = false;
738  for (const auto &Leaf : Tree.possible_leaves()) {
739    OS << Indent << "// Leaf name: " << Leaf.getName() << "\n";
740
741    const CombineRule *Rule = Leaf.getTargetData<CombineRule>();
742    const Record &RuleDef = Rule->getDef();
743
744    OS << Indent << "// Rule: " << RuleDef.getName() << "\n"
745       << Indent << "if (!isRuleDisabled(" << Rule->getID() << ")) {\n";
746
747    CodeExpansions Expansions;
748    for (const auto &VarBinding : Leaf.var_bindings()) {
749      if (VarBinding.isInstr())
750        Expansions.declare(VarBinding.getName(),
751                           "MIs[" + to_string(VarBinding.getInstrID()) + "]");
752      else
753        Expansions.declare(VarBinding.getName(),
754                           "MIs[" + to_string(VarBinding.getInstrID()) +
755                               "]->getOperand(" +
756                               to_string(VarBinding.getOpIdx()) + ")");
757    }
758    Rule->declareExpansions(Expansions);
759
760    DagInit *Applyer = RuleDef.getValueAsDag("Apply");
761    if (Applyer->getOperatorAsDef(RuleDef.getLoc())->getName() !=
762        "apply") {
763      PrintError(RuleDef.getLoc(), "Expected apply operator");
764      return;
765    }
766
767    OS << Indent << "  if (1\n";
768
769    // Attempt to emit code for any untested predicates left over. Note that
770    // isFullyTested() will remain false even if we succeed here and therefore
771    // combine rule elision will not be performed. This is because we do not
772    // know if there's any connection between the predicates for each leaf and
773    // therefore can't tell if one makes another unreachable. Ideally, the
774    // partitioner(s) would be sufficiently complete to prevent us from having
775    // untested predicates left over.
776    for (const GIMatchDagPredicate *Predicate : Leaf.untested_predicates()) {
777      if (Predicate->generateCheckCode(OS, (Indent + "      ").str(),
778                                       Expansions))
779        continue;
780      PrintError(RuleDef.getLoc(),
781                 "Unable to test predicate used in rule");
782      PrintNote(SMLoc(),
783                "This indicates an incomplete implementation in tablegen");
784      Predicate->print(errs());
785      errs() << "\n";
786      OS << Indent
787         << "llvm_unreachable(\"TableGen did not emit complete code for this "
788            "path\");\n";
789      break;
790    }
791
792    if (Rule->getMatchingFixupCode() &&
793        !Rule->getMatchingFixupCode()->getValue().empty()) {
      // FIXME: Single-use lambdas like this are a serious compile-time
795      // performance and memory issue. It's convenient for this early stage to
796      // defer some work to successive patches but we need to eliminate this
797      // before the ruleset grows to small-moderate size. Last time, it became
798      // a big problem for low-mem systems around the 500 rule mark but by the
799      // time we grow that large we should have merged the ISel match table
800      // mechanism with the Combiner.
801      OS << Indent << "      && [&]() {\n"
802         << Indent << "      "
803         << CodeExpander(Rule->getMatchingFixupCode()->getValue(), Expansions,
804                         Rule->getMatchingFixupCode()->getLoc(), ShowExpansions)
805         << "\n"
806         << Indent << "      return true;\n"
807         << Indent << "  }()";
808    }
809    OS << ") {\n" << Indent << "   ";
810
811    if (const CodeInit *Code = dyn_cast<CodeInit>(Applyer->getArg(0))) {
812      OS << CodeExpander(Code->getAsUnquotedString(), Expansions,
813                         Code->getLoc(), ShowExpansions)
814         << "\n"
815         << Indent << "    return true;\n"
816         << Indent << "  }\n";
817    } else {
818      PrintError(RuleDef.getLoc(), "Expected apply code block");
819      return;
820    }
821
822    OS << Indent << "}\n";
823
824    assert(Leaf.isFullyTraversed());
825
826    // If we didn't have any predicates left over and we're not using the
827    // trap-door we have to support arbitrary C++ code while we're migrating to
828    // the declarative style then we know that subsequent leaves are
829    // unreachable.
830    if (Leaf.isFullyTested() &&
831        (!Rule->getMatchingFixupCode() ||
832         Rule->getMatchingFixupCode()->getValue().empty())) {
833      AnyFullyTested = true;
834      OS << Indent
835         << "llvm_unreachable(\"Combine rule elision was incorrect\");\n"
836         << Indent << "return false;\n";
837    }
838  }
839  if (!AnyFullyTested)
840    OS << Indent << "return false;\n";
841}
842
843void GICombinerEmitter::run(raw_ostream &OS) {
844  gatherRules(Rules, Combiner->getValueAsListOfDefs("Rules"));
845  if (StopAfterParse) {
846    MatchDagCtx.print(errs());
847    PrintNote(Combiner->getLoc(),
848              "Terminating due to -gicombiner-stop-after-parse");
849    return;
850  }
851  if (ErrorsPrinted)
852    PrintFatalError(Combiner->getLoc(), "Failed to parse one or more rules");
853  LLVM_DEBUG(dbgs() << "Optimizing tree for " << Rules.size() << " rules\n");
854  std::unique_ptr<GIMatchTree> Tree;
855  {
856    NamedRegionTimer T("Optimize", "Time spent optimizing the combiner",
857                       "Code Generation", "Time spent generating code",
858                       TimeRegions);
859
860    GIMatchTreeBuilder TreeBuilder(0);
861    for (const auto &Rule : Rules) {
862      bool HadARoot = false;
863      for (const auto &Root : enumerate(Rule->getMatchDag().roots())) {
864        TreeBuilder.addLeaf(Rule->getName(), Root.index(), Rule->getMatchDag(),
865                            Rule.get());
866        HadARoot = true;
867      }
868      if (!HadARoot)
869        PrintFatalError(Rule->getDef().getLoc(), "All rules must have a root");
870    }
871
872    Tree = TreeBuilder.run();
873  }
874  if (StopAfterBuild) {
875    Tree->writeDOTGraph(outs());
876    PrintNote(Combiner->getLoc(),
877              "Terminating due to -gicombiner-stop-after-build");
878    return;
879  }
880
881  NamedRegionTimer T("Emit", "Time spent emitting the combiner",
882                     "Code Generation", "Time spent generating code",
883                     TimeRegions);
884  OS << "#ifdef " << Name.upper() << "_GENCOMBINERHELPER_DEPS\n"
885     << "#include \"llvm/ADT/SparseBitVector.h\"\n"
886     << "namespace llvm {\n"
887     << "extern cl::OptionCategory GICombinerOptionCategory;\n"
888     << "} // end namespace llvm\n"
889     << "#endif // ifdef " << Name.upper() << "_GENCOMBINERHELPER_DEPS\n\n";
890
891  OS << "#ifdef " << Name.upper() << "_GENCOMBINERHELPER_H\n"
892     << "class " << getClassName() << " {\n"
893     << "  SparseBitVector<> DisabledRules;\n"
894     << "\n"
895     << "public:\n"
896     << "  bool parseCommandLineOption();\n"
897     << "  bool isRuleDisabled(unsigned ID) const;\n"
898     << "  bool setRuleDisabled(StringRef RuleIdentifier);\n"
899     << "\n"
900     << "  bool tryCombineAll(\n"
901     << "    GISelChangeObserver &Observer,\n"
902     << "    MachineInstr &MI,\n"
903     << "    MachineIRBuilder &B,\n"
904     << "    CombinerHelper &Helper) const;\n"
905     << "};\n\n";
906
907  emitNameMatcher(OS);
908
909  OS << "bool " << getClassName()
910     << "::setRuleDisabled(StringRef RuleIdentifier) {\n"
911     << "  std::pair<StringRef, StringRef> RangePair = "
912        "RuleIdentifier.split('-');\n"
913     << "  if (!RangePair.second.empty()) {\n"
914     << "    const auto First = getRuleIdxForIdentifier(RangePair.first);\n"
915     << "    const auto Last = getRuleIdxForIdentifier(RangePair.second);\n"
916     << "    if (!First.hasValue() || !Last.hasValue())\n"
917     << "      return false;\n"
918     << "    if (First >= Last)\n"
919     << "      report_fatal_error(\"Beginning of range should be before end of "
920        "range\");\n"
921     << "    for (auto I = First.getValue(); I < Last.getValue(); ++I)\n"
922     << "      DisabledRules.set(I);\n"
923     << "    return true;\n"
924     << "  } else {\n"
925     << "    const auto I = getRuleIdxForIdentifier(RangePair.first);\n"
926     << "    if (!I.hasValue())\n"
927     << "      return false;\n"
928     << "    DisabledRules.set(I.getValue());\n"
929     << "    return true;\n"
930     << "  }\n"
931     << "  return false;\n"
932     << "}\n";
933
934  OS << "bool " << getClassName()
935     << "::isRuleDisabled(unsigned RuleID) const {\n"
936     << "  return DisabledRules.test(RuleID);\n"
937     << "}\n";
938  OS << "#endif // ifdef " << Name.upper() << "_GENCOMBINERHELPER_H\n\n";
939
940  OS << "#ifdef " << Name.upper() << "_GENCOMBINERHELPER_CPP\n"
941     << "\n"
942     << "cl::list<std::string> " << Name << "Option(\n"
943     << "    \"" << Name.lower() << "-disable-rule\",\n"
944     << "    cl::desc(\"Disable one or more combiner rules temporarily in "
945     << "the " << Name << " pass\"),\n"
946     << "    cl::CommaSeparated,\n"
947     << "    cl::Hidden,\n"
948     << "    cl::cat(GICombinerOptionCategory));\n"
949     << "\n"
950     << "bool " << getClassName() << "::parseCommandLineOption() {\n"
951     << "  for (const auto &Identifier : " << Name << "Option)\n"
952     << "    if (!setRuleDisabled(Identifier))\n"
953     << "      return false;\n"
954     << "  return true;\n"
955     << "}\n\n";
956
957  OS << "bool " << getClassName() << "::tryCombineAll(\n"
958     << "    GISelChangeObserver &Observer,\n"
959     << "    MachineInstr &MI,\n"
960     << "    MachineIRBuilder &B,\n"
961     << "    CombinerHelper &Helper) const {\n"
962     << "  MachineBasicBlock *MBB = MI.getParent();\n"
963     << "  MachineFunction *MF = MBB->getParent();\n"
964     << "  MachineRegisterInfo &MRI = MF->getRegInfo();\n"
965     << "  SmallVector<MachineInstr *, 8> MIs = { &MI };\n\n"
966     << "  (void)MBB; (void)MF; (void)MRI;\n\n";
967
968  OS << "  // Match data\n";
969  for (const auto &Rule : Rules)
970    for (const auto &I : Rule->matchdata_decls())
971      OS << "  " << I.getType() << " " << I.getVariableName() << ";\n";
972  OS << "\n";
973
974  OS << "  int Partition = -1;\n";
975  generateCodeForTree(OS, *Tree, "  ");
976  OS << "\n  return false;\n"
977     << "}\n"
978     << "#endif // ifdef " << Name.upper() << "_GENCOMBINERHELPER_CPP\n";
979}
980
981} // end anonymous namespace
982
983//===----------------------------------------------------------------------===//
984
985namespace llvm {
986void EmitGICombiner(RecordKeeper &RK, raw_ostream &OS) {
987  CodeGenTarget Target(RK);
988  emitSourceFileHeader("Global Combiner", OS);
989
990  if (SelectedCombiners.empty())
991    PrintFatalError("No combiners selected with -combiners");
992  for (const auto &Combiner : SelectedCombiners) {
993    Record *CombinerDef = RK.getDef(Combiner);
994    if (!CombinerDef)
995      PrintFatalError("Could not find " + Combiner);
996    GICombinerEmitter(RK, Target, Combiner, CombinerDef).run(OS);
997  }
998  NumPatternTotalStatistic = NumPatternTotal;
999}
1000
1001} // namespace llvm
1002