1//===--- TransAutoreleasePool.cpp - Transformations to ARC mode -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// rewriteAutoreleasePool:
10//
// Manually managed NSAutoreleasePool variables are rewritten as an
// @autoreleasepool scope:
//
//  NSAutoreleasePool *pool = [[NSAutoreleasePool alloc] init];
//  ...
//  [pool release];
// ---->
//  @autoreleasepool {
//  ...
//  }
20//
21// An NSAutoreleasePool will not be touched if:
22// - There is not a corresponding -release/-drain in the same scope
23// - Not all references of the NSAutoreleasePool variable can be removed
24// - There is a variable that is declared inside the intended @autorelease scope
25//   which is also used outside it.
26//
27//===----------------------------------------------------------------------===//
28
29#include "Transforms.h"
30#include "Internals.h"
31#include "clang/AST/ASTContext.h"
32#include "clang/Basic/SourceManager.h"
33#include "clang/Sema/SemaDiagnostic.h"
34#include <map>
35
36using namespace clang;
37using namespace arcmt;
38using namespace trans;
39
40namespace {
41
42class ReleaseCollector : public RecursiveASTVisitor<ReleaseCollector> {
43  Decl *Dcl;
44  SmallVectorImpl<ObjCMessageExpr *> &Releases;
45
46public:
47  ReleaseCollector(Decl *D, SmallVectorImpl<ObjCMessageExpr *> &releases)
48    : Dcl(D), Releases(releases) { }
49
50  bool VisitObjCMessageExpr(ObjCMessageExpr *E) {
51    if (!E->isInstanceMessage())
52      return true;
53    if (E->getMethodFamily() != OMF_release)
54      return true;
55    Expr *instance = E->getInstanceReceiver()->IgnoreParenCasts();
56    if (DeclRefExpr *DE = dyn_cast<DeclRefExpr>(instance)) {
57      if (DE->getDecl() == Dcl)
58        Releases.push_back(E);
59    }
60    return true;
61  }
62};
63
64}
65
66namespace {
67
68class AutoreleasePoolRewriter
69                         : public RecursiveASTVisitor<AutoreleasePoolRewriter> {
70public:
71  AutoreleasePoolRewriter(MigrationPass &pass)
72    : Body(nullptr), Pass(pass) {
73    PoolII = &pass.Ctx.Idents.get("NSAutoreleasePool");
74    DrainSel = pass.Ctx.Selectors.getNullarySelector(
75                                                 &pass.Ctx.Idents.get("drain"));
76  }
77
78  void transformBody(Stmt *body, Decl *ParentD) {
79    Body = body;
80    TraverseStmt(body);
81  }
82
83  ~AutoreleasePoolRewriter() {
84    SmallVector<VarDecl *, 8> VarsToHandle;
85
86    for (std::map<VarDecl *, PoolVarInfo>::iterator
87           I = PoolVars.begin(), E = PoolVars.end(); I != E; ++I) {
88      VarDecl *var = I->first;
89      PoolVarInfo &info = I->second;
90
91      // Check that we can handle/rewrite all references of the pool.
92
93      clearRefsIn(info.Dcl, info.Refs);
94      for (SmallVectorImpl<PoolScope>::iterator
95             scpI = info.Scopes.begin(),
96             scpE = info.Scopes.end(); scpI != scpE; ++scpI) {
97        PoolScope &scope = *scpI;
98        clearRefsIn(*scope.Begin, info.Refs);
99        clearRefsIn(*scope.End, info.Refs);
100        clearRefsIn(scope.Releases.begin(), scope.Releases.end(), info.Refs);
101      }
102
103      // Even if one reference is not handled we will not do anything about that
104      // pool variable.
105      if (info.Refs.empty())
106        VarsToHandle.push_back(var);
107    }
108
109    for (unsigned i = 0, e = VarsToHandle.size(); i != e; ++i) {
110      PoolVarInfo &info = PoolVars[VarsToHandle[i]];
111
112      Transaction Trans(Pass.TA);
113
114      clearUnavailableDiags(info.Dcl);
115      Pass.TA.removeStmt(info.Dcl);
116
117      // Add "@autoreleasepool { }"
118      for (SmallVectorImpl<PoolScope>::iterator
119             scpI = info.Scopes.begin(),
120             scpE = info.Scopes.end(); scpI != scpE; ++scpI) {
121        PoolScope &scope = *scpI;
122        clearUnavailableDiags(*scope.Begin);
123        clearUnavailableDiags(*scope.End);
124        if (scope.IsFollowedBySimpleReturnStmt) {
125          // Include the return in the scope.
126          Pass.TA.replaceStmt(*scope.Begin, "@autoreleasepool {");
127          Pass.TA.removeStmt(*scope.End);
128          Stmt::child_iterator retI = scope.End;
129          ++retI;
130          SourceLocation afterSemi =
131              findLocationAfterSemi((*retI)->getEndLoc(), Pass.Ctx);
132          assert(afterSemi.isValid() &&
133                 "Didn't we check before setting IsFollowedBySimpleReturnStmt "
134                 "to true?");
135          Pass.TA.insertAfterToken(afterSemi, "\n}");
136          Pass.TA.increaseIndentation(
137              SourceRange(scope.getIndentedRange().getBegin(),
138                          (*retI)->getEndLoc()),
139              scope.CompoundParent->getBeginLoc());
140        } else {
141          Pass.TA.replaceStmt(*scope.Begin, "@autoreleasepool {");
142          Pass.TA.replaceStmt(*scope.End, "}");
143          Pass.TA.increaseIndentation(scope.getIndentedRange(),
144                                      scope.CompoundParent->getBeginLoc());
145        }
146      }
147
148      // Remove rest of pool var references.
149      for (SmallVectorImpl<PoolScope>::iterator
150             scpI = info.Scopes.begin(),
151             scpE = info.Scopes.end(); scpI != scpE; ++scpI) {
152        PoolScope &scope = *scpI;
153        for (SmallVectorImpl<ObjCMessageExpr *>::iterator
154               relI = scope.Releases.begin(),
155               relE = scope.Releases.end(); relI != relE; ++relI) {
156          clearUnavailableDiags(*relI);
157          Pass.TA.removeStmt(*relI);
158        }
159      }
160    }
161  }
162
163  bool VisitCompoundStmt(CompoundStmt *S) {
164    SmallVector<PoolScope, 4> Scopes;
165
166    for (Stmt::child_iterator
167           I = S->body_begin(), E = S->body_end(); I != E; ++I) {
168      Stmt *child = getEssential(*I);
169      if (DeclStmt *DclS = dyn_cast<DeclStmt>(child)) {
170        if (DclS->isSingleDecl()) {
171          if (VarDecl *VD = dyn_cast<VarDecl>(DclS->getSingleDecl())) {
172            if (isNSAutoreleasePool(VD->getType())) {
173              PoolVarInfo &info = PoolVars[VD];
174              info.Dcl = DclS;
175              collectRefs(VD, S, info.Refs);
176              // Does this statement follow the pattern:
177              // NSAutoreleasePool * pool = [NSAutoreleasePool  new];
178              if (isPoolCreation(VD->getInit())) {
179                Scopes.push_back(PoolScope());
180                Scopes.back().PoolVar = VD;
181                Scopes.back().CompoundParent = S;
182                Scopes.back().Begin = I;
183              }
184            }
185          }
186        }
187      } else if (BinaryOperator *bop = dyn_cast<BinaryOperator>(child)) {
188        if (DeclRefExpr *dref = dyn_cast<DeclRefExpr>(bop->getLHS())) {
189          if (VarDecl *VD = dyn_cast<VarDecl>(dref->getDecl())) {
190            // Does this statement follow the pattern:
191            // pool = [NSAutoreleasePool  new];
192            if (isNSAutoreleasePool(VD->getType()) &&
193                isPoolCreation(bop->getRHS())) {
194              Scopes.push_back(PoolScope());
195              Scopes.back().PoolVar = VD;
196              Scopes.back().CompoundParent = S;
197              Scopes.back().Begin = I;
198            }
199          }
200        }
201      }
202
203      if (Scopes.empty())
204        continue;
205
206      if (isPoolDrain(Scopes.back().PoolVar, child)) {
207        PoolScope &scope = Scopes.back();
208        scope.End = I;
209        handlePoolScope(scope, S);
210        Scopes.pop_back();
211      }
212    }
213    return true;
214  }
215
216private:
217  void clearUnavailableDiags(Stmt *S) {
218    if (S)
219      Pass.TA.clearDiagnostic(diag::err_unavailable,
220                              diag::err_unavailable_message,
221                              S->getSourceRange());
222  }
223
224  struct PoolScope {
225    VarDecl *PoolVar;
226    CompoundStmt *CompoundParent;
227    Stmt::child_iterator Begin;
228    Stmt::child_iterator End;
229    bool IsFollowedBySimpleReturnStmt;
230    SmallVector<ObjCMessageExpr *, 4> Releases;
231
232    PoolScope() : PoolVar(nullptr), CompoundParent(nullptr), Begin(), End(),
233                  IsFollowedBySimpleReturnStmt(false) { }
234
235    SourceRange getIndentedRange() const {
236      Stmt::child_iterator rangeS = Begin;
237      ++rangeS;
238      if (rangeS == End)
239        return SourceRange();
240      Stmt::child_iterator rangeE = Begin;
241      for (Stmt::child_iterator I = rangeS; I != End; ++I)
242        ++rangeE;
243      return SourceRange((*rangeS)->getBeginLoc(), (*rangeE)->getEndLoc());
244    }
245  };
246
247  class NameReferenceChecker : public RecursiveASTVisitor<NameReferenceChecker>{
248    ASTContext &Ctx;
249    SourceRange ScopeRange;
250    SourceLocation &referenceLoc, &declarationLoc;
251
252  public:
253    NameReferenceChecker(ASTContext &ctx, PoolScope &scope,
254                         SourceLocation &referenceLoc,
255                         SourceLocation &declarationLoc)
256      : Ctx(ctx), referenceLoc(referenceLoc),
257        declarationLoc(declarationLoc) {
258      ScopeRange = SourceRange((*scope.Begin)->getBeginLoc(),
259                               (*scope.End)->getBeginLoc());
260    }
261
262    bool VisitDeclRefExpr(DeclRefExpr *E) {
263      return checkRef(E->getLocation(), E->getDecl()->getLocation());
264    }
265
266    bool VisitTypedefTypeLoc(TypedefTypeLoc TL) {
267      return checkRef(TL.getBeginLoc(), TL.getTypedefNameDecl()->getLocation());
268    }
269
270    bool VisitTagTypeLoc(TagTypeLoc TL) {
271      return checkRef(TL.getBeginLoc(), TL.getDecl()->getLocation());
272    }
273
274  private:
275    bool checkRef(SourceLocation refLoc, SourceLocation declLoc) {
276      if (isInScope(declLoc)) {
277        referenceLoc = refLoc;
278        declarationLoc = declLoc;
279        return false;
280      }
281      return true;
282    }
283
284    bool isInScope(SourceLocation loc) {
285      if (loc.isInvalid())
286        return false;
287
288      SourceManager &SM = Ctx.getSourceManager();
289      if (SM.isBeforeInTranslationUnit(loc, ScopeRange.getBegin()))
290        return false;
291      return SM.isBeforeInTranslationUnit(loc, ScopeRange.getEnd());
292    }
293  };
294
295  void handlePoolScope(PoolScope &scope, CompoundStmt *compoundS) {
296    // Check that all names declared inside the scope are not used
297    // outside the scope.
298    {
299      bool nameUsedOutsideScope = false;
300      SourceLocation referenceLoc, declarationLoc;
301      Stmt::child_iterator SI = scope.End, SE = compoundS->body_end();
302      ++SI;
303      // Check if the autoreleasepool scope is followed by a simple return
304      // statement, in which case we will include the return in the scope.
305      if (SI != SE)
306        if (ReturnStmt *retS = dyn_cast<ReturnStmt>(*SI))
307          if ((retS->getRetValue() == nullptr ||
308               isa<DeclRefExpr>(retS->getRetValue()->IgnoreParenCasts())) &&
309              findLocationAfterSemi(retS->getEndLoc(), Pass.Ctx).isValid()) {
310            scope.IsFollowedBySimpleReturnStmt = true;
311            ++SI; // the return will be included in scope, don't check it.
312          }
313
314      for (; SI != SE; ++SI) {
315        nameUsedOutsideScope = !NameReferenceChecker(Pass.Ctx, scope,
316                                                     referenceLoc,
317                                              declarationLoc).TraverseStmt(*SI);
318        if (nameUsedOutsideScope)
319          break;
320      }
321
322      // If not all references were cleared it means some variables/typenames/etc
323      // declared inside the pool scope are used outside of it.
324      // We won't try to rewrite the pool.
325      if (nameUsedOutsideScope) {
326        Pass.TA.reportError("a name is referenced outside the "
327            "NSAutoreleasePool scope that it was declared in", referenceLoc);
328        Pass.TA.reportNote("name declared here", declarationLoc);
329        Pass.TA.reportNote("intended @autoreleasepool scope begins here",
330                           (*scope.Begin)->getBeginLoc());
331        Pass.TA.reportNote("intended @autoreleasepool scope ends here",
332                           (*scope.End)->getBeginLoc());
333        return;
334      }
335    }
336
337    // Collect all releases of the pool; they will be removed.
338    {
339      ReleaseCollector releaseColl(scope.PoolVar, scope.Releases);
340      Stmt::child_iterator I = scope.Begin;
341      ++I;
342      for (; I != scope.End; ++I)
343        releaseColl.TraverseStmt(*I);
344    }
345
346    PoolVars[scope.PoolVar].Scopes.push_back(scope);
347  }
348
349  bool isPoolCreation(Expr *E) {
350    if (!E) return false;
351    E = getEssential(E);
352    ObjCMessageExpr *ME = dyn_cast<ObjCMessageExpr>(E);
353    if (!ME) return false;
354    if (ME->getMethodFamily() == OMF_new &&
355        ME->getReceiverKind() == ObjCMessageExpr::Class &&
356        isNSAutoreleasePool(ME->getReceiverInterface()))
357      return true;
358    if (ME->getReceiverKind() == ObjCMessageExpr::Instance &&
359        ME->getMethodFamily() == OMF_init) {
360      Expr *rec = getEssential(ME->getInstanceReceiver());
361      if (ObjCMessageExpr *recME = dyn_cast_or_null<ObjCMessageExpr>(rec)) {
362        if (recME->getMethodFamily() == OMF_alloc &&
363            recME->getReceiverKind() == ObjCMessageExpr::Class &&
364            isNSAutoreleasePool(recME->getReceiverInterface()))
365          return true;
366      }
367    }
368
369    return false;
370  }
371
372  bool isPoolDrain(VarDecl *poolVar, Stmt *S) {
373    if (!S) return false;
374    S = getEssential(S);
375    ObjCMessageExpr *ME = dyn_cast<ObjCMessageExpr>(S);
376    if (!ME) return false;
377    if (ME->getReceiverKind() == ObjCMessageExpr::Instance) {
378      Expr *rec = getEssential(ME->getInstanceReceiver());
379      if (DeclRefExpr *dref = dyn_cast<DeclRefExpr>(rec))
380        if (dref->getDecl() == poolVar)
381          return ME->getMethodFamily() == OMF_release ||
382                 ME->getSelector() == DrainSel;
383    }
384
385    return false;
386  }
387
388  bool isNSAutoreleasePool(ObjCInterfaceDecl *IDecl) {
389    return IDecl && IDecl->getIdentifier() == PoolII;
390  }
391
392  bool isNSAutoreleasePool(QualType Ty) {
393    QualType pointee = Ty->getPointeeType();
394    if (pointee.isNull())
395      return false;
396    if (const ObjCInterfaceType *interT = pointee->getAs<ObjCInterfaceType>())
397      return isNSAutoreleasePool(interT->getDecl());
398    return false;
399  }
400
401  static Expr *getEssential(Expr *E) {
402    return cast<Expr>(getEssential((Stmt*)E));
403  }
404  static Stmt *getEssential(Stmt *S) {
405    if (FullExpr *FE = dyn_cast<FullExpr>(S))
406      S = FE->getSubExpr();
407    if (Expr *E = dyn_cast<Expr>(S))
408      S = E->IgnoreParenCasts();
409    return S;
410  }
411
412  Stmt *Body;
413  MigrationPass &Pass;
414
415  IdentifierInfo *PoolII;
416  Selector DrainSel;
417
418  struct PoolVarInfo {
419    DeclStmt *Dcl;
420    ExprSet Refs;
421    SmallVector<PoolScope, 2> Scopes;
422
423    PoolVarInfo() : Dcl(nullptr) { }
424  };
425
426  std::map<VarDecl *, PoolVarInfo> PoolVars;
427};
428
429} // anonymous namespace
430
431void trans::rewriteAutoreleasePool(MigrationPass &pass) {
432  BodyTransform<AutoreleasePoolRewriter> trans(pass);
433  trans.TraverseDecl(pass.Ctx.getTranslationUnitDecl());
434}
435