1//===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===// 2// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6// 7//===----------------------------------------------------------------------===// 8/// 9/// This file contains functions which are used to decide if a loop worth to be 10/// unrolled. Moreover, these functions manages the stack of loop which is 11/// tracked by the ProgramState. 12/// 13//===----------------------------------------------------------------------===// 14 15#include "clang/ASTMatchers/ASTMatchers.h" 16#include "clang/ASTMatchers/ASTMatchFinder.h" 17#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h" 18#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h" 19#include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h" 20 21using namespace clang; 22using namespace ento; 23using namespace clang::ast_matchers; 24 25static const int MAXIMUM_STEP_UNROLLED = 128; 26 27struct LoopState { 28private: 29 enum Kind { Normal, Unrolled } K; 30 const Stmt *LoopStmt; 31 const LocationContext *LCtx; 32 unsigned maxStep; 33 LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N) 34 : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {} 35 36public: 37 static LoopState getNormal(const Stmt *S, const LocationContext *L, 38 unsigned N) { 39 return LoopState(Normal, S, L, N); 40 } 41 static LoopState getUnrolled(const Stmt *S, const LocationContext *L, 42 unsigned N) { 43 return LoopState(Unrolled, S, L, N); 44 } 45 bool isUnrolled() const { return K == Unrolled; } 46 unsigned getMaxStep() const { return maxStep; } 47 const Stmt *getLoopStmt() const { return LoopStmt; } 48 const LocationContext *getLocationContext() const { return LCtx; } 49 bool operator==(const LoopState &X) const { 50 return K == X.K && LoopStmt == X.LoopStmt; 51 } 52 void Profile(llvm::FoldingSetNodeID &ID) const { 53 ID.AddInteger(K); 54 ID.AddPointer(LoopStmt); 55 ID.AddPointer(LCtx); 56 ID.AddInteger(maxStep); 57 } 58}; 59 60// The tracked stack of loops. The stack indicates that which loops the 61// simulated element contained by. The loops are marked depending if we decided 62// to unroll them. 63// TODO: The loop stack should not need to be in the program state since it is 64// lexical in nature. Instead, the stack of loops should be tracked in the 65// LocationContext. 66REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState) 67 68namespace clang { 69namespace ento { 70 71static bool isLoopStmt(const Stmt *S) { 72 return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S)); 73} 74 75ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) { 76 auto LS = State->get<LoopStack>(); 77 if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt) 78 State = State->set<LoopStack>(LS.getTail()); 79 return State; 80} 81 82static internal::Matcher<Stmt> simpleCondition(StringRef BindName) { 83 return binaryOperator(anyOf(hasOperatorName("<"), hasOperatorName(">"), 84 hasOperatorName("<="), hasOperatorName(">="), 85 hasOperatorName("!=")), 86 hasEitherOperand(ignoringParenImpCasts(declRefExpr( 87 to(varDecl(hasType(isInteger())).bind(BindName))))), 88 hasEitherOperand(ignoringParenImpCasts( 89 integerLiteral().bind("boundNum")))) 90 .bind("conditionOperator"); 91} 92 93static internal::Matcher<Stmt> 94changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) { 95 return anyOf( 96 unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")), 97 hasUnaryOperand(ignoringParenImpCasts( 98 declRefExpr(to(varDecl(VarNodeMatcher)))))), 99 binaryOperator(isAssignmentOperator(), 100 hasLHS(ignoringParenImpCasts( 101 declRefExpr(to(varDecl(VarNodeMatcher))))))); 102} 103 104static internal::Matcher<Stmt> 105callByRef(internal::Matcher<Decl> VarNodeMatcher) { 106 return callExpr(forEachArgumentWithParam( 107 declRefExpr(to(varDecl(VarNodeMatcher))), 108 parmVarDecl(hasType(references(qualType(unless(isConstQualified()))))))); 109} 110 111static internal::Matcher<Stmt> 112assignedToRef(internal::Matcher<Decl> VarNodeMatcher) { 113 return declStmt(hasDescendant(varDecl( 114 allOf(hasType(referenceType()), 115 hasInitializer(anyOf( 116 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))), 117 declRefExpr(to(varDecl(VarNodeMatcher))))))))); 118} 119 120static internal::Matcher<Stmt> 121getAddrTo(internal::Matcher<Decl> VarNodeMatcher) { 122 return unaryOperator( 123 hasOperatorName("&"), 124 hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher)))); 125} 126 127static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) { 128 return hasDescendant(stmt( 129 anyOf(gotoStmt(), switchStmt(), returnStmt(), 130 // Escaping and not known mutation of the loop counter is handled 131 // by exclusion of assigning and address-of operators and 132 // pass-by-ref function calls on the loop counter from the body. 133 changeIntBoundNode(equalsBoundNode(std::string(NodeName))), 134 callByRef(equalsBoundNode(std::string(NodeName))), 135 getAddrTo(equalsBoundNode(std::string(NodeName))), 136 assignedToRef(equalsBoundNode(std::string(NodeName)))))); 137} 138 139static internal::Matcher<Stmt> forLoopMatcher() { 140 return forStmt( 141 hasCondition(simpleCondition("initVarName")), 142 // Initialization should match the form: 'int i = 6' or 'i = 42'. 143 hasLoopInit( 144 anyOf(declStmt(hasSingleDecl( 145 varDecl(allOf(hasInitializer(ignoringParenImpCasts( 146 integerLiteral().bind("initNum"))), 147 equalsBoundNode("initVarName"))))), 148 binaryOperator(hasLHS(declRefExpr(to(varDecl( 149 equalsBoundNode("initVarName"))))), 150 hasRHS(ignoringParenImpCasts( 151 integerLiteral().bind("initNum")))))), 152 // Incrementation should be a simple increment or decrement 153 // operator call. 154 hasIncrement(unaryOperator( 155 anyOf(hasOperatorName("++"), hasOperatorName("--")), 156 hasUnaryOperand(declRefExpr( 157 to(varDecl(allOf(equalsBoundNode("initVarName"), 158 hasType(isInteger())))))))), 159 unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop"); 160} 161 162static bool isPossiblyEscaped(const VarDecl *VD, ExplodedNode *N) { 163 // Global variables assumed as escaped variables. 164 if (VD->hasGlobalStorage()) 165 return true; 166 167 const bool isParm = isa<ParmVarDecl>(VD); 168 // Reference parameters are assumed as escaped variables. 169 if (isParm && VD->getType()->isReferenceType()) 170 return true; 171 172 while (!N->pred_empty()) { 173 // FIXME: getStmtForDiagnostics() does nasty things in order to provide 174 // a valid statement for body farms, do we need this behavior here? 175 const Stmt *S = N->getStmtForDiagnostics(); 176 if (!S) { 177 N = N->getFirstPred(); 178 continue; 179 } 180 181 if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) { 182 for (const Decl *D : DS->decls()) { 183 // Once we reach the declaration of the VD we can return. 184 if (D->getCanonicalDecl() == VD) 185 return false; 186 } 187 } 188 // Check the usage of the pass-by-ref function calls and adress-of operator 189 // on VD and reference initialized by VD. 190 ASTContext &ASTCtx = 191 N->getLocationContext()->getAnalysisDeclContext()->getASTContext(); 192 auto Match = 193 match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)), 194 assignedToRef(equalsNode(VD)))), 195 *S, ASTCtx); 196 if (!Match.empty()) 197 return true; 198 199 N = N->getFirstPred(); 200 } 201 202 // Parameter declaration will not be found. 203 if (isParm) 204 return false; 205 206 llvm_unreachable("Reached root without finding the declaration of VD"); 207} 208 209bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx, 210 ExplodedNode *Pred, unsigned &maxStep) { 211 212 if (!isLoopStmt(LoopStmt)) 213 return false; 214 215 // TODO: Match the cases where the bound is not a concrete literal but an 216 // integer with known value 217 auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx); 218 if (Matches.empty()) 219 return false; 220 221 auto CounterVar = Matches[0].getNodeAs<VarDecl>("initVarName"); 222 llvm::APInt BoundNum = 223 Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue(); 224 llvm::APInt InitNum = 225 Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue(); 226 auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator"); 227 if (InitNum.getBitWidth() != BoundNum.getBitWidth()) { 228 InitNum = InitNum.zextOrSelf(BoundNum.getBitWidth()); 229 BoundNum = BoundNum.zextOrSelf(InitNum.getBitWidth()); 230 } 231 232 if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE) 233 maxStep = (BoundNum - InitNum + 1).abs().getZExtValue(); 234 else 235 maxStep = (BoundNum - InitNum).abs().getZExtValue(); 236 237 // Check if the counter of the loop is not escaped before. 238 return !isPossiblyEscaped(CounterVar->getCanonicalDecl(), Pred); 239} 240 241bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) { 242 const Stmt *S = nullptr; 243 while (!N->pred_empty()) { 244 if (N->succ_size() > 1) 245 return true; 246 247 ProgramPoint P = N->getLocation(); 248 if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>()) 249 S = BE->getBlock()->getTerminatorStmt(); 250 251 if (S == LoopStmt) 252 return false; 253 254 N = N->getFirstPred(); 255 } 256 257 llvm_unreachable("Reached root without encountering the previous step"); 258} 259 260// updateLoopStack is called on every basic block, therefore it needs to be fast 261ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx, 262 ExplodedNode *Pred, unsigned maxVisitOnPath) { 263 auto State = Pred->getState(); 264 auto LCtx = Pred->getLocationContext(); 265 266 if (!isLoopStmt(LoopStmt)) 267 return State; 268 269 auto LS = State->get<LoopStack>(); 270 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() && 271 LCtx == LS.getHead().getLocationContext()) { 272 if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) { 273 State = State->set<LoopStack>(LS.getTail()); 274 State = State->add<LoopStack>( 275 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); 276 } 277 return State; 278 } 279 unsigned maxStep; 280 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) { 281 State = State->add<LoopStack>( 282 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); 283 return State; 284 } 285 286 unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep()); 287 288 unsigned innerMaxStep = maxStep * outerStep; 289 if (innerMaxStep > MAXIMUM_STEP_UNROLLED) 290 State = State->add<LoopStack>( 291 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); 292 else 293 State = State->add<LoopStack>( 294 LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep)); 295 return State; 296} 297 298bool isUnrolledState(ProgramStateRef State) { 299 auto LS = State->get<LoopStack>(); 300 if (LS.isEmpty() || !LS.getHead().isUnrolled()) 301 return false; 302 return true; 303} 304} 305} 306