1//===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Adds braces around case statements that "contain" the initialization of a
// retaining variable, which would otherwise trigger the "switch case is in
// protected scope" error.
11//
12//===----------------------------------------------------------------------===//
13
14#include "Transforms.h"
15#include "Internals.h"
16#include "clang/AST/ASTContext.h"
17#include "clang/Sema/SemaDiagnostic.h"
18
19using namespace clang;
20using namespace arcmt;
21using namespace trans;
22
23namespace {
24
25class LocalRefsCollector : public RecursiveASTVisitor<LocalRefsCollector> {
26  SmallVectorImpl<DeclRefExpr *> &Refs;
27
28public:
29  LocalRefsCollector(SmallVectorImpl<DeclRefExpr *> &refs)
30    : Refs(refs) { }
31
32  bool VisitDeclRefExpr(DeclRefExpr *E) {
33    if (ValueDecl *D = E->getDecl())
34      if (D->getDeclContext()->getRedeclContext()->isFunctionOrMethod())
35        Refs.push_back(E);
36    return true;
37  }
38};
39
40struct CaseInfo {
41  SwitchCase *SC;
42  SourceRange Range;
43  enum {
44    St_Unchecked,
45    St_CannotFix,
46    St_Fixed
47  } State;
48
49  CaseInfo() : SC(nullptr), State(St_Unchecked) {}
50  CaseInfo(SwitchCase *S, SourceRange Range)
51    : SC(S), Range(Range), State(St_Unchecked) {}
52};
53
54class CaseCollector : public RecursiveASTVisitor<CaseCollector> {
55  ParentMap &PMap;
56  SmallVectorImpl<CaseInfo> &Cases;
57
58public:
59  CaseCollector(ParentMap &PMap, SmallVectorImpl<CaseInfo> &Cases)
60    : PMap(PMap), Cases(Cases) { }
61
62  bool VisitSwitchStmt(SwitchStmt *S) {
63    SwitchCase *Curr = S->getSwitchCaseList();
64    if (!Curr)
65      return true;
66    Stmt *Parent = getCaseParent(Curr);
67    Curr = Curr->getNextSwitchCase();
68    // Make sure all case statements are in the same scope.
69    while (Curr) {
70      if (getCaseParent(Curr) != Parent)
71        return true;
72      Curr = Curr->getNextSwitchCase();
73    }
74
75    SourceLocation NextLoc = S->getEndLoc();
76    Curr = S->getSwitchCaseList();
77    // We iterate over case statements in reverse source-order.
78    while (Curr) {
79      Cases.push_back(
80          CaseInfo(Curr, SourceRange(Curr->getBeginLoc(), NextLoc)));
81      NextLoc = Curr->getBeginLoc();
82      Curr = Curr->getNextSwitchCase();
83    }
84    return true;
85  }
86
87  Stmt *getCaseParent(SwitchCase *S) {
88    Stmt *Parent = PMap.getParent(S);
89    while (Parent && (isa<SwitchCase>(Parent) || isa<LabelStmt>(Parent)))
90      Parent = PMap.getParent(Parent);
91    return Parent;
92  }
93};
94
95class ProtectedScopeFixer {
96  MigrationPass &Pass;
97  SourceManager &SM;
98  SmallVector<CaseInfo, 16> Cases;
99  SmallVector<DeclRefExpr *, 16> LocalRefs;
100
101public:
102  ProtectedScopeFixer(BodyContext &BodyCtx)
103    : Pass(BodyCtx.getMigrationContext().Pass),
104      SM(Pass.Ctx.getSourceManager()) {
105
106    CaseCollector(BodyCtx.getParentMap(), Cases)
107        .TraverseStmt(BodyCtx.getTopStmt());
108    LocalRefsCollector(LocalRefs).TraverseStmt(BodyCtx.getTopStmt());
109
110    SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange();
111    const CapturedDiagList &DiagList = Pass.getDiags();
112    // Copy the diagnostics so we don't have to worry about invaliding iterators
113    // from the diagnostic list.
114    SmallVector<StoredDiagnostic, 16> StoredDiags;
115    StoredDiags.append(DiagList.begin(), DiagList.end());
116    SmallVectorImpl<StoredDiagnostic>::iterator
117        I = StoredDiags.begin(), E = StoredDiags.end();
118    while (I != E) {
119      if (I->getID() == diag::err_switch_into_protected_scope &&
120          isInRange(I->getLocation(), BodyRange)) {
121        handleProtectedScopeError(I, E);
122        continue;
123      }
124      ++I;
125    }
126  }
127
128  void handleProtectedScopeError(
129                             SmallVectorImpl<StoredDiagnostic>::iterator &DiagI,
130                             SmallVectorImpl<StoredDiagnostic>::iterator DiagE){
131    Transaction Trans(Pass.TA);
132    assert(DiagI->getID() == diag::err_switch_into_protected_scope);
133    SourceLocation ErrLoc = DiagI->getLocation();
134    bool handledAllNotes = true;
135    ++DiagI;
136    for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note;
137         ++DiagI) {
138      if (!handleProtectedNote(*DiagI))
139        handledAllNotes = false;
140    }
141
142    if (handledAllNotes)
143      Pass.TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc);
144  }
145
146  bool handleProtectedNote(const StoredDiagnostic &Diag) {
147    assert(Diag.getLevel() == DiagnosticsEngine::Note);
148
149    for (unsigned i = 0; i != Cases.size(); i++) {
150      CaseInfo &info = Cases[i];
151      if (isInRange(Diag.getLocation(), info.Range)) {
152
153        if (info.State == CaseInfo::St_Unchecked)
154          tryFixing(info);
155        assert(info.State != CaseInfo::St_Unchecked);
156
157        if (info.State == CaseInfo::St_Fixed) {
158          Pass.TA.clearDiagnostic(Diag.getID(), Diag.getLocation());
159          return true;
160        }
161        return false;
162      }
163    }
164
165    return false;
166  }
167
168  void tryFixing(CaseInfo &info) {
169    assert(info.State == CaseInfo::St_Unchecked);
170    if (hasVarReferencedOutside(info)) {
171      info.State = CaseInfo::St_CannotFix;
172      return;
173    }
174
175    Pass.TA.insertAfterToken(info.SC->getColonLoc(), " {");
176    Pass.TA.insert(info.Range.getEnd(), "}\n");
177    info.State = CaseInfo::St_Fixed;
178  }
179
180  bool hasVarReferencedOutside(CaseInfo &info) {
181    for (unsigned i = 0, e = LocalRefs.size(); i != e; ++i) {
182      DeclRefExpr *DRE = LocalRefs[i];
183      if (isInRange(DRE->getDecl()->getLocation(), info.Range) &&
184          !isInRange(DRE->getLocation(), info.Range))
185        return true;
186    }
187    return false;
188  }
189
190  bool isInRange(SourceLocation Loc, SourceRange R) {
191    if (Loc.isInvalid())
192      return false;
193    return !SM.isBeforeInTranslationUnit(Loc, R.getBegin()) &&
194            SM.isBeforeInTranslationUnit(Loc, R.getEnd());
195  }
196};
197
198} // anonymous namespace
199
200void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) {
201  ProtectedScopeFixer Fix(BodyCtx);
202}
203