1351278Sdim//===-- VPlanPredicator.cpp -------------------------------------*- C++ -*-===//
2351278Sdim//
3351278Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4351278Sdim// See https://llvm.org/LICENSE.txt for license information.
5351278Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6351278Sdim//
7351278Sdim//===----------------------------------------------------------------------===//
8351278Sdim///
9351278Sdim/// \file
10351278Sdim/// This file implements the VPlanPredicator class which contains the public
11351278Sdim/// interfaces to predicate and linearize the VPlan region.
12351278Sdim///
13351278Sdim//===----------------------------------------------------------------------===//
14351278Sdim
15351278Sdim#include "VPlanPredicator.h"
16351278Sdim#include "VPlan.h"
17351278Sdim#include "llvm/ADT/DepthFirstIterator.h"
18351278Sdim#include "llvm/ADT/GraphTraits.h"
19351278Sdim#include "llvm/ADT/PostOrderIterator.h"
20351278Sdim#include "llvm/Support/Debug.h"
21351278Sdim#include "llvm/Support/raw_ostream.h"
22351278Sdim
23351278Sdim#define DEBUG_TYPE "VPlanPredicator"
24351278Sdim
25351278Sdimusing namespace llvm;
26351278Sdim
27351278Sdim// Generate VPInstructions at the beginning of CurrBB that calculate the
28351278Sdim// predicate being propagated from PredBB to CurrBB depending on the edge type
29351278Sdim// between them. For example if:
30351278Sdim//  i.  PredBB is controlled by predicate %BP, and
31351278Sdim//  ii. The edge PredBB->CurrBB is the false edge, controlled by the condition
32351278Sdim//  bit value %CBV then this function will generate the following two
33351278Sdim//  VPInstructions at the start of CurrBB:
34351278Sdim//   %IntermediateVal = not %CBV
35351278Sdim//   %FinalVal        = and %BP %IntermediateVal
36351278Sdim// It returns %FinalVal.
37351278SdimVPValue *VPlanPredicator::getOrCreateNotPredicate(VPBasicBlock *PredBB,
38351278Sdim                                                  VPBasicBlock *CurrBB) {
39351278Sdim  VPValue *CBV = PredBB->getCondBit();
40351278Sdim
41351278Sdim  // Set the intermediate value - this is either 'CBV', or 'not CBV'
42351278Sdim  // depending on the edge type.
43351278Sdim  EdgeType ET = getEdgeTypeBetween(PredBB, CurrBB);
44351278Sdim  VPValue *IntermediateVal = nullptr;
45351278Sdim  switch (ET) {
46351278Sdim  case EdgeType::TRUE_EDGE:
47351278Sdim    // CurrBB is the true successor of PredBB - nothing to do here.
48351278Sdim    IntermediateVal = CBV;
49351278Sdim    break;
50351278Sdim
51351278Sdim  case EdgeType::FALSE_EDGE:
52351278Sdim    // CurrBB is the False successor of PredBB - compute not of CBV.
53351278Sdim    IntermediateVal = Builder.createNot(CBV);
54351278Sdim    break;
55351278Sdim  }
56351278Sdim
57351278Sdim  // Now AND intermediate value with PredBB's block predicate if it has one.
58351278Sdim  VPValue *BP = PredBB->getPredicate();
59351278Sdim  if (BP)
60351278Sdim    return Builder.createAnd(BP, IntermediateVal);
61351278Sdim  else
62351278Sdim    return IntermediateVal;
63351278Sdim}
64351278Sdim
65351278Sdim// Generate a tree of ORs for all IncomingPredicates in  WorkList.
66351278Sdim// Note: This function destroys the original Worklist.
67351278Sdim//
68351278Sdim// P1 P2 P3 P4 P5
69351278Sdim//  \ /   \ /  /
70351278Sdim//  OR1   OR2 /
71351278Sdim//    \    | /
72351278Sdim//     \   +/-+
73351278Sdim//      \  /  |
74351278Sdim//       OR3  |
75351278Sdim//         \  |
76351278Sdim//          OR4 <- Returns this
77351278Sdim//           |
78351278Sdim//
79351278Sdim// The algorithm uses a worklist of predicates as its main data structure.
80351278Sdim// We pop a pair of values from the front (e.g. P1 and P2), generate an OR
81351278Sdim// (in this example OR1), and push it back. In this example the worklist
82351278Sdim// contains {P3, P4, P5, OR1}.
83351278Sdim// The process iterates until we have only one element in the Worklist (OR4).
84351278Sdim// The last element is the root predicate which is returned.
85351278SdimVPValue *VPlanPredicator::genPredicateTree(std::list<VPValue *> &Worklist) {
86351278Sdim  if (Worklist.empty())
87351278Sdim    return nullptr;
88351278Sdim
89351278Sdim  // The worklist initially contains all the leaf nodes. Initialize the tree
90351278Sdim  // using them.
91351278Sdim  while (Worklist.size() >= 2) {
92351278Sdim    // Pop a pair of values from the front.
93351278Sdim    VPValue *LHS = Worklist.front();
94351278Sdim    Worklist.pop_front();
95351278Sdim    VPValue *RHS = Worklist.front();
96351278Sdim    Worklist.pop_front();
97351278Sdim
98351278Sdim    // Create an OR of these values.
99351278Sdim    VPValue *Or = Builder.createOr(LHS, RHS);
100351278Sdim
101351278Sdim    // Push OR to the back of the worklist.
102351278Sdim    Worklist.push_back(Or);
103351278Sdim  }
104351278Sdim
105351278Sdim  assert(Worklist.size() == 1 && "Expected 1 item in worklist");
106351278Sdim
107351278Sdim  // The root is the last node in the worklist.
108351278Sdim  VPValue *Root = Worklist.front();
109351278Sdim
110351278Sdim  // This root needs to replace the existing block predicate. This is done in
111351278Sdim  // the caller function.
112351278Sdim  return Root;
113351278Sdim}
114351278Sdim
115351278Sdim// Return whether the edge FromBlock -> ToBlock is a TRUE_EDGE or FALSE_EDGE
116351278SdimVPlanPredicator::EdgeType
117351278SdimVPlanPredicator::getEdgeTypeBetween(VPBlockBase *FromBlock,
118351278Sdim                                    VPBlockBase *ToBlock) {
119351278Sdim  unsigned Count = 0;
120351278Sdim  for (VPBlockBase *SuccBlock : FromBlock->getSuccessors()) {
121351278Sdim    if (SuccBlock == ToBlock) {
122351278Sdim      assert(Count < 2 && "Switch not supported currently");
123351278Sdim      return (Count == 0) ? EdgeType::TRUE_EDGE : EdgeType::FALSE_EDGE;
124351278Sdim    }
125351278Sdim    Count++;
126351278Sdim  }
127351278Sdim
128351278Sdim  llvm_unreachable("Broken getEdgeTypeBetween");
129351278Sdim}
130351278Sdim
131351278Sdim// Generate all predicates needed for CurrBlock by going through its immediate
132351278Sdim// predecessor blocks.
133351278Sdimvoid VPlanPredicator::createOrPropagatePredicates(VPBlockBase *CurrBlock,
134351278Sdim                                                  VPRegionBlock *Region) {
135351278Sdim  // Blocks that dominate region exit inherit the predicate from the region.
136351278Sdim  // Return after setting the predicate.
137351278Sdim  if (VPDomTree.dominates(CurrBlock, Region->getExit())) {
138351278Sdim    VPValue *RegionBP = Region->getPredicate();
139351278Sdim    CurrBlock->setPredicate(RegionBP);
140351278Sdim    return;
141351278Sdim  }
142351278Sdim
143351278Sdim  // Collect all incoming predicates in a worklist.
144351278Sdim  std::list<VPValue *> IncomingPredicates;
145351278Sdim
146351278Sdim  // Set the builder's insertion point to the top of the current BB
147351278Sdim  VPBasicBlock *CurrBB = cast<VPBasicBlock>(CurrBlock->getEntryBasicBlock());
148351278Sdim  Builder.setInsertPoint(CurrBB, CurrBB->begin());
149351278Sdim
150351278Sdim  // For each predecessor, generate the VPInstructions required for
151351278Sdim  // computing 'BP AND (not) CBV" at the top of CurrBB.
152351278Sdim  // Collect the outcome of this calculation for all predecessors
153351278Sdim  // into IncomingPredicates.
154351278Sdim  for (VPBlockBase *PredBlock : CurrBlock->getPredecessors()) {
155351278Sdim    // Skip back-edges
156351278Sdim    if (VPBlockUtils::isBackEdge(PredBlock, CurrBlock, VPLI))
157351278Sdim      continue;
158351278Sdim
159351278Sdim    VPValue *IncomingPredicate = nullptr;
160351278Sdim    unsigned NumPredSuccsNoBE =
161351278Sdim        VPBlockUtils::countSuccessorsNoBE(PredBlock, VPLI);
162351278Sdim
163351278Sdim    // If there is an unconditional branch to the currBB, then we don't create
164351278Sdim    // edge predicates. We use the predecessor's block predicate instead.
165351278Sdim    if (NumPredSuccsNoBE == 1)
166351278Sdim      IncomingPredicate = PredBlock->getPredicate();
167351278Sdim    else if (NumPredSuccsNoBE == 2) {
168351278Sdim      // Emit recipes into CurrBlock if required
169351278Sdim      assert(isa<VPBasicBlock>(PredBlock) && "Only BBs have multiple exits");
170351278Sdim      IncomingPredicate =
171351278Sdim          getOrCreateNotPredicate(cast<VPBasicBlock>(PredBlock), CurrBB);
172351278Sdim    } else
173351278Sdim      llvm_unreachable("FIXME: switch statement ?");
174351278Sdim
175351278Sdim    if (IncomingPredicate)
176351278Sdim      IncomingPredicates.push_back(IncomingPredicate);
177351278Sdim  }
178351278Sdim
179351278Sdim  // Logically OR all incoming predicates by building the Predicate Tree.
180351278Sdim  VPValue *Predicate = genPredicateTree(IncomingPredicates);
181351278Sdim
182351278Sdim  // Now update the block's predicate with the new one.
183351278Sdim  CurrBlock->setPredicate(Predicate);
184351278Sdim}
185351278Sdim
186351278Sdim// Generate all predicates needed for Region.
187351278Sdimvoid VPlanPredicator::predicateRegionRec(VPRegionBlock *Region) {
188351278Sdim  VPBasicBlock *EntryBlock = cast<VPBasicBlock>(Region->getEntry());
189351278Sdim  ReversePostOrderTraversal<VPBlockBase *> RPOT(EntryBlock);
190351278Sdim
191351278Sdim  // Generate edge predicates and append them to the block predicate. RPO is
192351278Sdim  // necessary since the predecessor blocks' block predicate needs to be set
193351278Sdim  // before the current block's block predicate can be computed.
194351278Sdim  for (VPBlockBase *Block : make_range(RPOT.begin(), RPOT.end())) {
195351278Sdim    // TODO: Handle nested regions once we start generating the same.
196351278Sdim    assert(!isa<VPRegionBlock>(Block) && "Nested region not expected");
197351278Sdim    createOrPropagatePredicates(Block, Region);
198351278Sdim  }
199351278Sdim}
200351278Sdim
201351278Sdim// Linearize the CFG within Region.
202351278Sdim// TODO: Predication and linearization need RPOT for every region.
203351278Sdim// This traversal is expensive. Since predication is not adding new
204351278Sdim// blocks, we should be able to compute RPOT once in predication and
205351278Sdim// reuse it here. This becomes even more important once we have nested
206351278Sdim// regions.
207351278Sdimvoid VPlanPredicator::linearizeRegionRec(VPRegionBlock *Region) {
208351278Sdim  ReversePostOrderTraversal<VPBlockBase *> RPOT(Region->getEntry());
209351278Sdim  VPBlockBase *PrevBlock = nullptr;
210351278Sdim
211351278Sdim  for (VPBlockBase *CurrBlock : make_range(RPOT.begin(), RPOT.end())) {
212351278Sdim    // TODO: Handle nested regions once we start generating the same.
213351278Sdim    assert(!isa<VPRegionBlock>(CurrBlock) && "Nested region not expected");
214351278Sdim
215351278Sdim    // Linearize control flow by adding an unconditional edge between PrevBlock
216351278Sdim    // and CurrBlock skipping loop headers and latches to keep intact loop
217351278Sdim    // header predecessors and loop latch successors.
218351278Sdim    if (PrevBlock && !VPLI->isLoopHeader(CurrBlock) &&
219351278Sdim        !VPBlockUtils::blockIsLoopLatch(PrevBlock, VPLI)) {
220351278Sdim
221351278Sdim      LLVM_DEBUG(dbgs() << "Linearizing: " << PrevBlock->getName() << "->"
222351278Sdim                        << CurrBlock->getName() << "\n");
223351278Sdim
224351278Sdim      PrevBlock->clearSuccessors();
225351278Sdim      CurrBlock->clearPredecessors();
226351278Sdim      VPBlockUtils::connectBlocks(PrevBlock, CurrBlock);
227351278Sdim    }
228351278Sdim
229351278Sdim    PrevBlock = CurrBlock;
230351278Sdim  }
231351278Sdim}
232351278Sdim
233351278Sdim// Entry point. The driver function for the predicator.
234351278Sdimvoid VPlanPredicator::predicate(void) {
235351278Sdim  // Predicate the blocks within Region.
236351278Sdim  predicateRegionRec(cast<VPRegionBlock>(Plan.getEntry()));
237351278Sdim
238351278Sdim  // Linearlize the blocks with Region.
239351278Sdim  linearizeRegionRec(cast<VPRegionBlock>(Plan.getEntry()));
240351278Sdim}
241351278Sdim
242351278SdimVPlanPredicator::VPlanPredicator(VPlan &Plan)
243351278Sdim    : Plan(Plan), VPLI(&(Plan.getVPLoopInfo())) {
244351278Sdim  // FIXME: Predicator is currently computing the dominator information for the
245351278Sdim  // top region. Once we start storing dominator information in a VPRegionBlock,
246351278Sdim  // we can avoid this recalculation.
247351278Sdim  VPDomTree.recalculate(*(cast<VPRegionBlock>(Plan.getEntry())));
248351278Sdim}
249