1//===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
/// This file contains functions that decide whether a loop is worth unrolling.
/// These functions also manage the stack of loops that is tracked by the
/// ProgramState.
12///
13//===----------------------------------------------------------------------===//
14
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"

#include <cstdint>
#include <limits>
20
21using namespace clang;
22using namespace ento;
23using namespace clang::ast_matchers;
24
/// Upper bound on the total step count of a loop nest (a loop's own trip
/// count multiplied by the budget of its enclosing loop) for which the loop is
/// still completely unrolled.
static constexpr int MAXIMUM_STEP_UNROLLED = 128;
26
27struct LoopState {
28private:
29  enum Kind { Normal, Unrolled } K;
30  const Stmt *LoopStmt;
31  const LocationContext *LCtx;
32  unsigned maxStep;
33  LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N)
34      : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {}
35
36public:
37  static LoopState getNormal(const Stmt *S, const LocationContext *L,
38                             unsigned N) {
39    return LoopState(Normal, S, L, N);
40  }
41  static LoopState getUnrolled(const Stmt *S, const LocationContext *L,
42                               unsigned N) {
43    return LoopState(Unrolled, S, L, N);
44  }
45  bool isUnrolled() const { return K == Unrolled; }
46  unsigned getMaxStep() const { return maxStep; }
47  const Stmt *getLoopStmt() const { return LoopStmt; }
48  const LocationContext *getLocationContext() const { return LCtx; }
49  bool operator==(const LoopState &X) const {
50    return K == X.K && LoopStmt == X.LoopStmt;
51  }
52  void Profile(llvm::FoldingSetNodeID &ID) const {
53    ID.AddInteger(K);
54    ID.AddPointer(LoopStmt);
55    ID.AddPointer(LCtx);
56    ID.AddInteger(maxStep);
57  }
58};
59
// The tracked stack of loops. Each entry is a loop that contains the currently
// simulated element; new loops are added at the head, so the head is the
// innermost loop. Each entry records whether we decided to unroll that loop.
// TODO: The loop stack should not need to be in the program state since it is
// lexical in nature. Instead, the stack of loops should be tracked in the
// LocationContext.
REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
67
68namespace clang {
69namespace ento {
70
71static bool isLoopStmt(const Stmt *S) {
72  return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S));
73}
74
75ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
76  auto LS = State->get<LoopStack>();
77  if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
78    State = State->set<LoopStack>(LS.getTail());
79  return State;
80}
81
82static internal::Matcher<Stmt> simpleCondition(StringRef BindName) {
83  return binaryOperator(anyOf(hasOperatorName("<"), hasOperatorName(">"),
84                              hasOperatorName("<="), hasOperatorName(">="),
85                              hasOperatorName("!=")),
86                        hasEitherOperand(ignoringParenImpCasts(declRefExpr(
87                            to(varDecl(hasType(isInteger())).bind(BindName))))),
88                        hasEitherOperand(ignoringParenImpCasts(
89                            integerLiteral().bind("boundNum"))))
90      .bind("conditionOperator");
91}
92
93static internal::Matcher<Stmt>
94changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
95  return anyOf(
96      unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
97                    hasUnaryOperand(ignoringParenImpCasts(
98                        declRefExpr(to(varDecl(VarNodeMatcher)))))),
99      binaryOperator(isAssignmentOperator(),
100                     hasLHS(ignoringParenImpCasts(
101                         declRefExpr(to(varDecl(VarNodeMatcher)))))));
102}
103
104static internal::Matcher<Stmt>
105callByRef(internal::Matcher<Decl> VarNodeMatcher) {
106  return callExpr(forEachArgumentWithParam(
107      declRefExpr(to(varDecl(VarNodeMatcher))),
108      parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
109}
110
111static internal::Matcher<Stmt>
112assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
113  return declStmt(hasDescendant(varDecl(
114      allOf(hasType(referenceType()),
115            hasInitializer(anyOf(
116                initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
117                declRefExpr(to(varDecl(VarNodeMatcher)))))))));
118}
119
120static internal::Matcher<Stmt>
121getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
122  return unaryOperator(
123      hasOperatorName("&"),
124      hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
125}
126
127static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
128  return hasDescendant(stmt(
129      anyOf(gotoStmt(), switchStmt(), returnStmt(),
130            // Escaping and not known mutation of the loop counter is handled
131            // by exclusion of assigning and address-of operators and
132            // pass-by-ref function calls on the loop counter from the body.
133            changeIntBoundNode(equalsBoundNode(std::string(NodeName))),
134            callByRef(equalsBoundNode(std::string(NodeName))),
135            getAddrTo(equalsBoundNode(std::string(NodeName))),
136            assignedToRef(equalsBoundNode(std::string(NodeName))))));
137}
138
139static internal::Matcher<Stmt> forLoopMatcher() {
140  return forStmt(
141             hasCondition(simpleCondition("initVarName")),
142             // Initialization should match the form: 'int i = 6' or 'i = 42'.
143             hasLoopInit(
144                 anyOf(declStmt(hasSingleDecl(
145                           varDecl(allOf(hasInitializer(ignoringParenImpCasts(
146                                             integerLiteral().bind("initNum"))),
147                                         equalsBoundNode("initVarName"))))),
148                       binaryOperator(hasLHS(declRefExpr(to(varDecl(
149                                          equalsBoundNode("initVarName"))))),
150                                      hasRHS(ignoringParenImpCasts(
151                                          integerLiteral().bind("initNum")))))),
152             // Incrementation should be a simple increment or decrement
153             // operator call.
154             hasIncrement(unaryOperator(
155                 anyOf(hasOperatorName("++"), hasOperatorName("--")),
156                 hasUnaryOperand(declRefExpr(
157                     to(varDecl(allOf(equalsBoundNode("initVarName"),
158                                      hasType(isInteger())))))))),
159             unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop");
160}
161
162static bool isPossiblyEscaped(const VarDecl *VD, ExplodedNode *N) {
163  // Global variables assumed as escaped variables.
164  if (VD->hasGlobalStorage())
165    return true;
166
167  const bool isParm = isa<ParmVarDecl>(VD);
168  // Reference parameters are assumed as escaped variables.
169  if (isParm && VD->getType()->isReferenceType())
170    return true;
171
172  while (!N->pred_empty()) {
173    // FIXME: getStmtForDiagnostics() does nasty things in order to provide
174    // a valid statement for body farms, do we need this behavior here?
175    const Stmt *S = N->getStmtForDiagnostics();
176    if (!S) {
177      N = N->getFirstPred();
178      continue;
179    }
180
181    if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
182      for (const Decl *D : DS->decls()) {
183        // Once we reach the declaration of the VD we can return.
184        if (D->getCanonicalDecl() == VD)
185          return false;
186      }
187    }
188    // Check the usage of the pass-by-ref function calls and adress-of operator
189    // on VD and reference initialized by VD.
190    ASTContext &ASTCtx =
191        N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
192    auto Match =
193        match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
194                         assignedToRef(equalsNode(VD)))),
195              *S, ASTCtx);
196    if (!Match.empty())
197      return true;
198
199    N = N->getFirstPred();
200  }
201
202  // Parameter declaration will not be found.
203  if (isParm)
204    return false;
205
206  llvm_unreachable("Reached root without finding the declaration of VD");
207}
208
209bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
210                            ExplodedNode *Pred, unsigned &maxStep) {
211
212  if (!isLoopStmt(LoopStmt))
213    return false;
214
215  // TODO: Match the cases where the bound is not a concrete literal but an
216  // integer with known value
217  auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
218  if (Matches.empty())
219    return false;
220
221  auto CounterVar = Matches[0].getNodeAs<VarDecl>("initVarName");
222  llvm::APInt BoundNum =
223      Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue();
224  llvm::APInt InitNum =
225      Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue();
226  auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator");
227  if (InitNum.getBitWidth() != BoundNum.getBitWidth()) {
228    InitNum = InitNum.zextOrSelf(BoundNum.getBitWidth());
229    BoundNum = BoundNum.zextOrSelf(InitNum.getBitWidth());
230  }
231
232  if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE)
233    maxStep = (BoundNum - InitNum + 1).abs().getZExtValue();
234  else
235    maxStep = (BoundNum - InitNum).abs().getZExtValue();
236
237  // Check if the counter of the loop is not escaped before.
238  return !isPossiblyEscaped(CounterVar->getCanonicalDecl(), Pred);
239}
240
241bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) {
242  const Stmt *S = nullptr;
243  while (!N->pred_empty()) {
244    if (N->succ_size() > 1)
245      return true;
246
247    ProgramPoint P = N->getLocation();
248    if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
249      S = BE->getBlock()->getTerminatorStmt();
250
251    if (S == LoopStmt)
252      return false;
253
254    N = N->getFirstPred();
255  }
256
257  llvm_unreachable("Reached root without encountering the previous step");
258}
259
260// updateLoopStack is called on every basic block, therefore it needs to be fast
261ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
262                                ExplodedNode *Pred, unsigned maxVisitOnPath) {
263  auto State = Pred->getState();
264  auto LCtx = Pred->getLocationContext();
265
266  if (!isLoopStmt(LoopStmt))
267    return State;
268
269  auto LS = State->get<LoopStack>();
270  if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
271      LCtx == LS.getHead().getLocationContext()) {
272    if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) {
273      State = State->set<LoopStack>(LS.getTail());
274      State = State->add<LoopStack>(
275          LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
276    }
277    return State;
278  }
279  unsigned maxStep;
280  if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) {
281    State = State->add<LoopStack>(
282        LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
283    return State;
284  }
285
286  unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep());
287
288  unsigned innerMaxStep = maxStep * outerStep;
289  if (innerMaxStep > MAXIMUM_STEP_UNROLLED)
290    State = State->add<LoopStack>(
291        LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath));
292  else
293    State = State->add<LoopStack>(
294        LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep));
295  return State;
296}
297
298bool isUnrolledState(ProgramStateRef State) {
299  auto LS = State->get<LoopStack>();
300  if (LS.isEmpty() || !LS.getHead().isUnrolled())
301    return false;
302  return true;
303}
304}
305}
306