1//===- VarLenCodeEmitterGen.cpp - CEG for variable-length insts -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The CodeEmitterGen component for variable-length instructions.
10//
11// The basic CodeEmitterGen is almost exclusively designed for fixed-
12// length instructions. A good analogy for its encoding scheme is how printf
// works: the (immutable) format string represents the fixed values in the
// encoded instruction, while placeholders (e.g. %something) represent the
// encodings of instruction operands.
16// ```
17// printf("1101 %src 1001 %dst", <encoded value for operand `src`>,
18//                               <encoded value for operand `dst`>);
19// ```
20// VarLenCodeEmitterGen in this file provides an alternative encoding scheme
21// that works more like a C++ stream operator:
22// ```
23// OS << 0b1101;
24// if (Cond)
25//   OS << OperandEncoding0;
26// OS << 0b1001 << OperandEncoding1;
27// ```
// You are free to concatenate encoding fragments of arbitrary types (and
// sizes) at any bit position, which gives more flexibility when defining the
// encodings of variable-length instructions.
31//
// More specifically, the instruction encoding is represented by a DAG-typed
// `Inst` field. Here is an example:
34// ```
35// dag Inst = (descend 0b1101, (operand "$src", 4), 0b1001,
36//                     (operand "$dst", 4));
37// ```
38// It represents the following instruction encoding:
39// ```
40// MSB                                                     LSB
41// 1101<encoding for operand src>1001<encoding for operand dst>
42// ```
43// For more details about DAG operators in the above snippet, please
44// refer to \file include/llvm/Target/Target.td.
45//
// VarLenCodeEmitterGen converts the above DAG into the same helper function
// that CodeEmitterGen generates, `MCCodeEmitter::getBinaryCodeForInstr`
// (except for a few details).
49//
50//===----------------------------------------------------------------------===//
51
52#include "VarLenCodeEmitterGen.h"
53#include "CodeGenHwModes.h"
54#include "CodeGenInstruction.h"
55#include "CodeGenTarget.h"
56#include "InfoByHwMode.h"
57#include "llvm/ADT/ArrayRef.h"
58#include "llvm/ADT/DenseMap.h"
59#include "llvm/Support/raw_ostream.h"
60#include "llvm/TableGen/Error.h"
61
62using namespace llvm;
63
64namespace {
65
66class VarLenCodeEmitterGen {
67  RecordKeeper &Records;
68
69  DenseMap<Record *, VarLenInst> VarLenInsts;
70
71  // Emit based values (i.e. fixed bits in the encoded instructions)
72  void emitInstructionBaseValues(
73      raw_ostream &OS,
74      ArrayRef<const CodeGenInstruction *> NumberedInstructions,
75      CodeGenTarget &Target, int HwMode = -1);
76
77  std::string getInstructionCase(Record *R, CodeGenTarget &Target);
78  std::string getInstructionCaseForEncoding(Record *R, Record *EncodingDef,
79                                            CodeGenTarget &Target);
80
81public:
82  explicit VarLenCodeEmitterGen(RecordKeeper &R) : Records(R) {}
83
84  void run(raw_ostream &OS);
85};
86} // end anonymous namespace
87
88// Get the name of custom encoder or decoder, if there is any.
89// Returns `{encoder name, decoder name}`.
90static std::pair<StringRef, StringRef> getCustomCoders(ArrayRef<Init *> Args) {
91  std::pair<StringRef, StringRef> Result;
92  for (const auto *Arg : Args) {
93    const auto *DI = dyn_cast<DagInit>(Arg);
94    if (!DI)
95      continue;
96    const Init *Op = DI->getOperator();
97    if (!isa<DefInit>(Op))
98      continue;
99    // syntax: `(<encoder | decoder> "function name")`
100    StringRef OpName = cast<DefInit>(Op)->getDef()->getName();
101    if (OpName != "encoder" && OpName != "decoder")
102      continue;
103    if (!DI->getNumArgs() || !isa<StringInit>(DI->getArg(0)))
104      PrintFatalError("expected '" + OpName +
105                      "' directive to be followed by a custom function name.");
106    StringRef FuncName = cast<StringInit>(DI->getArg(0))->getValue();
107    if (OpName == "encoder")
108      Result.first = FuncName;
109    else
110      Result.second = FuncName;
111  }
112  return Result;
113}
114
115VarLenInst::VarLenInst(const DagInit *DI, const RecordVal *TheDef)
116    : TheDef(TheDef), NumBits(0U) {
117  buildRec(DI);
118  for (const auto &S : Segments)
119    NumBits += S.BitWidth;
120}
121
122void VarLenInst::buildRec(const DagInit *DI) {
123  assert(TheDef && "The def record is nullptr ?");
124
125  std::string Op = DI->getOperator()->getAsString();
126
127  if (Op == "ascend" || Op == "descend") {
128    bool Reverse = Op == "descend";
129    int i = Reverse ? DI->getNumArgs() - 1 : 0;
130    int e = Reverse ? -1 : DI->getNumArgs();
131    int s = Reverse ? -1 : 1;
132    for (; i != e; i += s) {
133      const Init *Arg = DI->getArg(i);
134      if (const auto *BI = dyn_cast<BitsInit>(Arg)) {
135        if (!BI->isComplete())
136          PrintFatalError(TheDef->getLoc(),
137                          "Expecting complete bits init in `" + Op + "`");
138        Segments.push_back({BI->getNumBits(), BI});
139      } else if (const auto *BI = dyn_cast<BitInit>(Arg)) {
140        if (!BI->isConcrete())
141          PrintFatalError(TheDef->getLoc(),
142                          "Expecting concrete bit init in `" + Op + "`");
143        Segments.push_back({1, BI});
144      } else if (const auto *SubDI = dyn_cast<DagInit>(Arg)) {
145        buildRec(SubDI);
146      } else {
147        PrintFatalError(TheDef->getLoc(), "Unrecognized type of argument in `" +
148                                              Op + "`: " + Arg->getAsString());
149      }
150    }
151  } else if (Op == "operand") {
152    // (operand <operand name>, <# of bits>,
153    //          [(encoder <custom encoder>)][, (decoder <custom decoder>)])
154    if (DI->getNumArgs() < 2)
155      PrintFatalError(TheDef->getLoc(),
156                      "Expecting at least 2 arguments for `operand`");
157    HasDynamicSegment = true;
158    const Init *OperandName = DI->getArg(0), *NumBits = DI->getArg(1);
159    if (!isa<StringInit>(OperandName) || !isa<IntInit>(NumBits))
160      PrintFatalError(TheDef->getLoc(), "Invalid argument types for `operand`");
161
162    auto NumBitsVal = cast<IntInit>(NumBits)->getValue();
163    if (NumBitsVal <= 0)
164      PrintFatalError(TheDef->getLoc(), "Invalid number of bits for `operand`");
165
166    auto [CustomEncoder, CustomDecoder] =
167        getCustomCoders(DI->getArgs().slice(2));
168    Segments.push_back({static_cast<unsigned>(NumBitsVal), OperandName,
169                        CustomEncoder, CustomDecoder});
170  } else if (Op == "slice") {
171    // (slice <operand name>, <high / low bit>, <low / high bit>,
172    //        [(encoder <custom encoder>)][, (decoder <custom decoder>)])
173    if (DI->getNumArgs() < 3)
174      PrintFatalError(TheDef->getLoc(),
175                      "Expecting at least 3 arguments for `slice`");
176    HasDynamicSegment = true;
177    Init *OperandName = DI->getArg(0), *HiBit = DI->getArg(1),
178         *LoBit = DI->getArg(2);
179    if (!isa<StringInit>(OperandName) || !isa<IntInit>(HiBit) ||
180        !isa<IntInit>(LoBit))
181      PrintFatalError(TheDef->getLoc(), "Invalid argument types for `slice`");
182
183    auto HiBitVal = cast<IntInit>(HiBit)->getValue(),
184         LoBitVal = cast<IntInit>(LoBit)->getValue();
185    if (HiBitVal < 0 || LoBitVal < 0)
186      PrintFatalError(TheDef->getLoc(), "Invalid bit range for `slice`");
187    bool NeedSwap = false;
188    unsigned NumBits = 0U;
189    if (HiBitVal < LoBitVal) {
190      NeedSwap = true;
191      NumBits = static_cast<unsigned>(LoBitVal - HiBitVal + 1);
192    } else {
193      NumBits = static_cast<unsigned>(HiBitVal - LoBitVal + 1);
194    }
195
196    auto [CustomEncoder, CustomDecoder] =
197        getCustomCoders(DI->getArgs().slice(3));
198
199    if (NeedSwap) {
200      // Normalization: Hi bit should always be the second argument.
201      Init *const NewArgs[] = {OperandName, LoBit, HiBit};
202      Segments.push_back({NumBits,
203                          DagInit::get(DI->getOperator(), nullptr, NewArgs, {}),
204                          CustomEncoder, CustomDecoder});
205    } else {
206      Segments.push_back({NumBits, DI, CustomEncoder, CustomDecoder});
207    }
208  }
209}
210
211void VarLenCodeEmitterGen::run(raw_ostream &OS) {
212  CodeGenTarget Target(Records);
213  auto Insts = Records.getAllDerivedDefinitions("Instruction");
214
215  auto NumberedInstructions = Target.getInstructionsByEnumValue();
216  const CodeGenHwModes &HWM = Target.getHwModes();
217
218  // The set of HwModes used by instruction encodings.
219  std::set<unsigned> HwModes;
220  for (const CodeGenInstruction *CGI : NumberedInstructions) {
221    Record *R = CGI->TheDef;
222
223    // Create the corresponding VarLenInst instance.
224    if (R->getValueAsString("Namespace") == "TargetOpcode" ||
225        R->getValueAsBit("isPseudo"))
226      continue;
227
228    if (const RecordVal *RV = R->getValue("EncodingInfos")) {
229      if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
230        EncodingInfoByHwMode EBM(DI->getDef(), HWM);
231        for (auto &KV : EBM) {
232          HwModes.insert(KV.first);
233          Record *EncodingDef = KV.second;
234          RecordVal *RV = EncodingDef->getValue("Inst");
235          DagInit *DI = cast<DagInit>(RV->getValue());
236          VarLenInsts.insert({EncodingDef, VarLenInst(DI, RV)});
237        }
238        continue;
239      }
240    }
241    RecordVal *RV = R->getValue("Inst");
242    DagInit *DI = cast<DagInit>(RV->getValue());
243    VarLenInsts.insert({R, VarLenInst(DI, RV)});
244  }
245
246  // Emit function declaration
247  OS << "void " << Target.getName()
248     << "MCCodeEmitter::getBinaryCodeForInstr(const MCInst &MI,\n"
249     << "    SmallVectorImpl<MCFixup> &Fixups,\n"
250     << "    APInt &Inst,\n"
251     << "    APInt &Scratch,\n"
252     << "    const MCSubtargetInfo &STI) const {\n";
253
254  // Emit instruction base values
255  if (HwModes.empty()) {
256    emitInstructionBaseValues(OS, NumberedInstructions, Target);
257  } else {
258    for (unsigned HwMode : HwModes)
259      emitInstructionBaseValues(OS, NumberedInstructions, Target, (int)HwMode);
260  }
261
262  if (!HwModes.empty()) {
263    OS << "  const unsigned **Index;\n";
264    OS << "  const uint64_t *InstBits;\n";
265    OS << "  unsigned HwMode = STI.getHwMode();\n";
266    OS << "  switch (HwMode) {\n";
267    OS << "  default: llvm_unreachable(\"Unknown hardware mode!\"); break;\n";
268    for (unsigned I : HwModes) {
269      OS << "  case " << I << ": InstBits = InstBits_" << HWM.getMode(I).Name
270         << "; Index = Index_" << HWM.getMode(I).Name << "; break;\n";
271    }
272    OS << "  };\n";
273  }
274
275  // Emit helper function to retrieve base values.
276  OS << "  auto getInstBits = [&](unsigned Opcode) -> APInt {\n"
277     << "    unsigned NumBits = Index[Opcode][0];\n"
278     << "    if (!NumBits)\n"
279     << "      return APInt::getZeroWidth();\n"
280     << "    unsigned Idx = Index[Opcode][1];\n"
281     << "    ArrayRef<uint64_t> Data(&InstBits[Idx], "
282     << "APInt::getNumWords(NumBits));\n"
283     << "    return APInt(NumBits, Data);\n"
284     << "  };\n";
285
286  // Map to accumulate all the cases.
287  std::map<std::string, std::vector<std::string>> CaseMap;
288
289  // Construct all cases statement for each opcode
290  for (Record *R : Insts) {
291    if (R->getValueAsString("Namespace") == "TargetOpcode" ||
292        R->getValueAsBit("isPseudo"))
293      continue;
294    std::string InstName =
295        (R->getValueAsString("Namespace") + "::" + R->getName()).str();
296    std::string Case = getInstructionCase(R, Target);
297
298    CaseMap[Case].push_back(std::move(InstName));
299  }
300
301  // Emit initial function code
302  OS << "  const unsigned opcode = MI.getOpcode();\n"
303     << "  switch (opcode) {\n";
304
305  // Emit each case statement
306  for (const auto &C : CaseMap) {
307    const std::string &Case = C.first;
308    const auto &InstList = C.second;
309
310    ListSeparator LS("\n");
311    for (const auto &InstName : InstList)
312      OS << LS << "    case " << InstName << ":";
313
314    OS << " {\n";
315    OS << Case;
316    OS << "      break;\n"
317       << "    }\n";
318  }
319  // Default case: unhandled opcode
320  OS << "  default:\n"
321     << "    std::string msg;\n"
322     << "    raw_string_ostream Msg(msg);\n"
323     << "    Msg << \"Not supported instr: \" << MI;\n"
324     << "    report_fatal_error(Msg.str().c_str());\n"
325     << "  }\n";
326  OS << "}\n\n";
327}
328
329static void emitInstBits(raw_ostream &IS, raw_ostream &SS, const APInt &Bits,
330                         unsigned &Index) {
331  if (!Bits.getNumWords()) {
332    IS.indent(4) << "{/*NumBits*/0, /*Index*/0},";
333    return;
334  }
335
336  IS.indent(4) << "{/*NumBits*/" << Bits.getBitWidth() << ", "
337               << "/*Index*/" << Index << "},";
338
339  SS.indent(4);
340  for (unsigned I = 0; I < Bits.getNumWords(); ++I, ++Index)
341    SS << "UINT64_C(" << utostr(Bits.getRawData()[I]) << "),";
342}
343
344void VarLenCodeEmitterGen::emitInstructionBaseValues(
345    raw_ostream &OS, ArrayRef<const CodeGenInstruction *> NumberedInstructions,
346    CodeGenTarget &Target, int HwMode) {
347  std::string IndexArray, StorageArray;
348  raw_string_ostream IS(IndexArray), SS(StorageArray);
349
350  const CodeGenHwModes &HWM = Target.getHwModes();
351  if (HwMode == -1) {
352    IS << "  static const unsigned Index[][2] = {\n";
353    SS << "  static const uint64_t InstBits[] = {\n";
354  } else {
355    StringRef Name = HWM.getMode(HwMode).Name;
356    IS << "  static const unsigned Index_" << Name << "[][2] = {\n";
357    SS << "  static const uint64_t InstBits_" << Name << "[] = {\n";
358  }
359
360  unsigned NumFixedValueWords = 0U;
361  for (const CodeGenInstruction *CGI : NumberedInstructions) {
362    Record *R = CGI->TheDef;
363
364    if (R->getValueAsString("Namespace") == "TargetOpcode" ||
365        R->getValueAsBit("isPseudo")) {
366      IS.indent(4) << "{/*NumBits*/0, /*Index*/0},\n";
367      continue;
368    }
369
370    Record *EncodingDef = R;
371    if (const RecordVal *RV = R->getValue("EncodingInfos")) {
372      if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
373        EncodingInfoByHwMode EBM(DI->getDef(), HWM);
374        if (EBM.hasMode(HwMode))
375          EncodingDef = EBM.get(HwMode);
376      }
377    }
378
379    auto It = VarLenInsts.find(EncodingDef);
380    if (It == VarLenInsts.end())
381      PrintFatalError(EncodingDef, "VarLenInst not found for this record");
382    const VarLenInst &VLI = It->second;
383
384    unsigned i = 0U, BitWidth = VLI.size();
385
386    // Start by filling in fixed values.
387    APInt Value(BitWidth, 0);
388    auto SI = VLI.begin(), SE = VLI.end();
389    // Scan through all the segments that have fixed-bits values.
390    while (i < BitWidth && SI != SE) {
391      unsigned SegmentNumBits = SI->BitWidth;
392      if (const auto *BI = dyn_cast<BitsInit>(SI->Value)) {
393        for (unsigned Idx = 0U; Idx != SegmentNumBits; ++Idx) {
394          auto *B = cast<BitInit>(BI->getBit(Idx));
395          Value.setBitVal(i + Idx, B->getValue());
396        }
397      }
398      if (const auto *BI = dyn_cast<BitInit>(SI->Value))
399        Value.setBitVal(i, BI->getValue());
400
401      i += SegmentNumBits;
402      ++SI;
403    }
404
405    emitInstBits(IS, SS, Value, NumFixedValueWords);
406    IS << '\t' << "// " << R->getName() << "\n";
407    if (Value.getNumWords())
408      SS << '\t' << "// " << R->getName() << "\n";
409  }
410  IS.indent(4) << "{/*NumBits*/0, /*Index*/0}\n  };\n";
411  SS.indent(4) << "UINT64_C(0)\n  };\n";
412
413  OS << IS.str() << SS.str();
414}
415
416std::string VarLenCodeEmitterGen::getInstructionCase(Record *R,
417                                                     CodeGenTarget &Target) {
418  std::string Case;
419  if (const RecordVal *RV = R->getValue("EncodingInfos")) {
420    if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
421      const CodeGenHwModes &HWM = Target.getHwModes();
422      EncodingInfoByHwMode EBM(DI->getDef(), HWM);
423      Case += "      switch (HwMode) {\n";
424      Case += "      default: llvm_unreachable(\"Unhandled HwMode\");\n";
425      for (auto &KV : EBM) {
426        Case += "      case " + itostr(KV.first) + ": {\n";
427        Case += getInstructionCaseForEncoding(R, KV.second, Target);
428        Case += "      break;\n";
429        Case += "      }\n";
430      }
431      Case += "      }\n";
432      return Case;
433    }
434  }
435  return getInstructionCaseForEncoding(R, R, Target);
436}
437
438std::string VarLenCodeEmitterGen::getInstructionCaseForEncoding(
439    Record *R, Record *EncodingDef, CodeGenTarget &Target) {
440  auto It = VarLenInsts.find(EncodingDef);
441  if (It == VarLenInsts.end())
442    PrintFatalError(EncodingDef, "Parsed encoding record not found");
443  const VarLenInst &VLI = It->second;
444  size_t BitWidth = VLI.size();
445
446  CodeGenInstruction &CGI = Target.getInstruction(R);
447
448  std::string Case;
449  raw_string_ostream SS(Case);
450  // Resize the scratch buffer.
451  if (BitWidth && !VLI.isFixedValueOnly())
452    SS.indent(6) << "Scratch = Scratch.zext(" << BitWidth << ");\n";
453  // Populate based value.
454  SS.indent(6) << "Inst = getInstBits(opcode);\n";
455
456  // Process each segment in VLI.
457  size_t Offset = 0U;
458  for (const auto &ES : VLI) {
459    unsigned NumBits = ES.BitWidth;
460    const Init *Val = ES.Value;
461    // If it's a StringInit or DagInit, it's a reference to an operand
462    // or part of an operand.
463    if (isa<StringInit>(Val) || isa<DagInit>(Val)) {
464      StringRef OperandName;
465      unsigned LoBit = 0U;
466      if (const auto *SV = dyn_cast<StringInit>(Val)) {
467        OperandName = SV->getValue();
468      } else {
469        // Normalized: (slice <operand name>, <high bit>, <low bit>)
470        const auto *DV = cast<DagInit>(Val);
471        OperandName = cast<StringInit>(DV->getArg(0))->getValue();
472        LoBit = static_cast<unsigned>(cast<IntInit>(DV->getArg(2))->getValue());
473      }
474
475      auto OpIdx = CGI.Operands.ParseOperandName(OperandName);
476      unsigned FlatOpIdx = CGI.Operands.getFlattenedOperandNumber(OpIdx);
477      StringRef CustomEncoder =
478          CGI.Operands[OpIdx.first].EncoderMethodNames[OpIdx.second];
479      if (ES.CustomEncoder.size())
480        CustomEncoder = ES.CustomEncoder;
481
482      SS.indent(6) << "Scratch.clearAllBits();\n";
483      SS.indent(6) << "// op: " << OperandName.drop_front(1) << "\n";
484      if (CustomEncoder.empty())
485        SS.indent(6) << "getMachineOpValue(MI, MI.getOperand("
486                     << utostr(FlatOpIdx) << ")";
487      else
488        SS.indent(6) << CustomEncoder << "(MI, /*OpIdx=*/" << utostr(FlatOpIdx);
489
490      SS << ", /*Pos=*/" << utostr(Offset) << ", Scratch, Fixups, STI);\n";
491
492      SS.indent(6) << "Inst.insertBits("
493                   << "Scratch.extractBits(" << utostr(NumBits) << ", "
494                   << utostr(LoBit) << ")"
495                   << ", " << Offset << ");\n";
496    }
497    Offset += NumBits;
498  }
499
500  StringRef PostEmitter = R->getValueAsString("PostEncoderMethod");
501  if (!PostEmitter.empty())
502    SS.indent(6) << "Inst = " << PostEmitter << "(MI, Inst, STI);\n";
503
504  return Case;
505}
506
507namespace llvm {
508
509void emitVarLenCodeEmitter(RecordKeeper &R, raw_ostream &OS) {
510  VarLenCodeEmitterGen(R).run(OS);
511}
512
513} // end namespace llvm
514