1//===- MatrixUtils.h - Utilities to lower matrix intrinsics -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Utilities for generating tiled loops for matrix operations.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
14#define LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
15
16#include "llvm/ADT/StringRef.h"
17
18namespace llvm {
// Forward declarations of the IR/analysis types referenced by TileInfo.
// This header only ever uses them through pointers or references, so their
// full definitions are not needed here.
class BasicBlock;
class DomTreeUpdater;
class IRBuilderBase;
class Loop;
class LoopInfo;
class Value;
25
26/// A helper struct to create IR loop nests for tiling in IR of the following
27/// form:
28///   for ColumnLoop.Index = 0..NumColumns
29///     for RowLoop.Index = 0..NumRows
30///       for KLoop.Index = 0..NumInner
31struct TileInfo {
32  /// Number of rows of the matrix.
33  unsigned NumRows;
34
35  /// Number of columns of the matrix.
36  unsigned NumColumns;
37
38  /// Number of columns of the first matrix of a multiply /
39  /// number of rows of the second matrix of a multiply.
40  unsigned NumInner;
41
42  /// Number of rows/columns in a tile.
43  unsigned TileSize = -1;
44
45  /// Properties of a single loop used when generating the tiled loop nest.
46  struct MatrixLoop {
47    /// The index updated on every iteration.
48    Value *Index = nullptr;
49    /// The header and latch of the loop.
50    BasicBlock *Header = nullptr;
51    BasicBlock *Latch = nullptr;
52  };
53
54  /// The loop iterating on the rows.
55  MatrixLoop RowLoop;
56  /// The loop iterating on the columns.
57  MatrixLoop ColumnLoop;
58  /// The loop iterating on k (inner dimension).
59  MatrixLoop KLoop;
60
61  TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
62           unsigned TileSize)
63      : NumRows(NumRows), NumColumns(NumColumns), NumInner(NumInner),
64        TileSize(TileSize) {}
65
66  /// Creates an IR loop nests for tiling of the form below. Returns the block
67  /// for the inner loop body and sets {Column,Row,Inner}LoopHeader/Latch
68  /// fields.
69  ///
70  /// for ColumnLoop.Index = 0..NumColumns
71  ///   for RowLoop.Index = 0..NumRows
72  ///     for InnerLoop.Index = 0..NumInner
73  BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
74                               IRBuilderBase &B, DomTreeUpdater &DTU,
75                               LoopInfo &LI);
76
77private:
78  /// Creates a new loop with header, body and latch blocks that iterates from
79  /// [0, Bound). Updates \p Preheader to branch to the new header and uses \p
80  /// Exit as exit block.  Adds the new loop blocks to \L and applies dominator
81  /// tree updates to \p DTU.
82  static BasicBlock *CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
83                                Value *Bound, Value *Step, StringRef Name,
84                                IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
85                                LoopInfo &LI);
86};
87} // namespace llvm
88
89#endif
90