1//===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Adds braces around switch cases that "contain" the initialization of a
// retaining variable and thus trigger the "switch case is in protected scope"
// error.
11//
12//===----------------------------------------------------------------------===//
13
14#include "Internals.h"
15#include "Transforms.h"
16#include "clang/AST/ASTContext.h"
17#include "clang/Basic/SourceManager.h"
18#include "clang/Sema/SemaDiagnostic.h"
19
20using namespace clang;
21using namespace arcmt;
22using namespace trans;
23
24namespace {
25
26class LocalRefsCollector : public RecursiveASTVisitor<LocalRefsCollector> {
27  SmallVectorImpl<DeclRefExpr *> &Refs;
28
29public:
30  LocalRefsCollector(SmallVectorImpl<DeclRefExpr *> &refs)
31    : Refs(refs) { }
32
33  bool VisitDeclRefExpr(DeclRefExpr *E) {
34    if (ValueDecl *D = E->getDecl())
35      if (D->getDeclContext()->getRedeclContext()->isFunctionOrMethod())
36        Refs.push_back(E);
37    return true;
38  }
39};
40
41struct CaseInfo {
42  SwitchCase *SC;
43  SourceRange Range;
44  enum {
45    St_Unchecked,
46    St_CannotFix,
47    St_Fixed
48  } State;
49
50  CaseInfo() : SC(nullptr), State(St_Unchecked) {}
51  CaseInfo(SwitchCase *S, SourceRange Range)
52    : SC(S), Range(Range), State(St_Unchecked) {}
53};
54
55class CaseCollector : public RecursiveASTVisitor<CaseCollector> {
56  ParentMap &PMap;
57  SmallVectorImpl<CaseInfo> &Cases;
58
59public:
60  CaseCollector(ParentMap &PMap, SmallVectorImpl<CaseInfo> &Cases)
61    : PMap(PMap), Cases(Cases) { }
62
63  bool VisitSwitchStmt(SwitchStmt *S) {
64    SwitchCase *Curr = S->getSwitchCaseList();
65    if (!Curr)
66      return true;
67    Stmt *Parent = getCaseParent(Curr);
68    Curr = Curr->getNextSwitchCase();
69    // Make sure all case statements are in the same scope.
70    while (Curr) {
71      if (getCaseParent(Curr) != Parent)
72        return true;
73      Curr = Curr->getNextSwitchCase();
74    }
75
76    SourceLocation NextLoc = S->getEndLoc();
77    Curr = S->getSwitchCaseList();
78    // We iterate over case statements in reverse source-order.
79    while (Curr) {
80      Cases.push_back(
81          CaseInfo(Curr, SourceRange(Curr->getBeginLoc(), NextLoc)));
82      NextLoc = Curr->getBeginLoc();
83      Curr = Curr->getNextSwitchCase();
84    }
85    return true;
86  }
87
88  Stmt *getCaseParent(SwitchCase *S) {
89    Stmt *Parent = PMap.getParent(S);
90    while (Parent && (isa<SwitchCase>(Parent) || isa<LabelStmt>(Parent)))
91      Parent = PMap.getParent(Parent);
92    return Parent;
93  }
94};
95
96class ProtectedScopeFixer {
97  MigrationPass &Pass;
98  SourceManager &SM;
99  SmallVector<CaseInfo, 16> Cases;
100  SmallVector<DeclRefExpr *, 16> LocalRefs;
101
102public:
103  ProtectedScopeFixer(BodyContext &BodyCtx)
104    : Pass(BodyCtx.getMigrationContext().Pass),
105      SM(Pass.Ctx.getSourceManager()) {
106
107    CaseCollector(BodyCtx.getParentMap(), Cases)
108        .TraverseStmt(BodyCtx.getTopStmt());
109    LocalRefsCollector(LocalRefs).TraverseStmt(BodyCtx.getTopStmt());
110
111    SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange();
112    const CapturedDiagList &DiagList = Pass.getDiags();
113    // Copy the diagnostics so we don't have to worry about invaliding iterators
114    // from the diagnostic list.
115    SmallVector<StoredDiagnostic, 16> StoredDiags;
116    StoredDiags.append(DiagList.begin(), DiagList.end());
117    SmallVectorImpl<StoredDiagnostic>::iterator
118        I = StoredDiags.begin(), E = StoredDiags.end();
119    while (I != E) {
120      if (I->getID() == diag::err_switch_into_protected_scope &&
121          isInRange(I->getLocation(), BodyRange)) {
122        handleProtectedScopeError(I, E);
123        continue;
124      }
125      ++I;
126    }
127  }
128
129  void handleProtectedScopeError(
130                             SmallVectorImpl<StoredDiagnostic>::iterator &DiagI,
131                             SmallVectorImpl<StoredDiagnostic>::iterator DiagE){
132    Transaction Trans(Pass.TA);
133    assert(DiagI->getID() == diag::err_switch_into_protected_scope);
134    SourceLocation ErrLoc = DiagI->getLocation();
135    bool handledAllNotes = true;
136    ++DiagI;
137    for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note;
138         ++DiagI) {
139      if (!handleProtectedNote(*DiagI))
140        handledAllNotes = false;
141    }
142
143    if (handledAllNotes)
144      Pass.TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc);
145  }
146
147  bool handleProtectedNote(const StoredDiagnostic &Diag) {
148    assert(Diag.getLevel() == DiagnosticsEngine::Note);
149
150    for (unsigned i = 0; i != Cases.size(); i++) {
151      CaseInfo &info = Cases[i];
152      if (isInRange(Diag.getLocation(), info.Range)) {
153
154        if (info.State == CaseInfo::St_Unchecked)
155          tryFixing(info);
156        assert(info.State != CaseInfo::St_Unchecked);
157
158        if (info.State == CaseInfo::St_Fixed) {
159          Pass.TA.clearDiagnostic(Diag.getID(), Diag.getLocation());
160          return true;
161        }
162        return false;
163      }
164    }
165
166    return false;
167  }
168
169  void tryFixing(CaseInfo &info) {
170    assert(info.State == CaseInfo::St_Unchecked);
171    if (hasVarReferencedOutside(info)) {
172      info.State = CaseInfo::St_CannotFix;
173      return;
174    }
175
176    Pass.TA.insertAfterToken(info.SC->getColonLoc(), " {");
177    Pass.TA.insert(info.Range.getEnd(), "}\n");
178    info.State = CaseInfo::St_Fixed;
179  }
180
181  bool hasVarReferencedOutside(CaseInfo &info) {
182    for (unsigned i = 0, e = LocalRefs.size(); i != e; ++i) {
183      DeclRefExpr *DRE = LocalRefs[i];
184      if (isInRange(DRE->getDecl()->getLocation(), info.Range) &&
185          !isInRange(DRE->getLocation(), info.Range))
186        return true;
187    }
188    return false;
189  }
190
191  bool isInRange(SourceLocation Loc, SourceRange R) {
192    if (Loc.isInvalid())
193      return false;
194    return !SM.isBeforeInTranslationUnit(Loc, R.getBegin()) &&
195            SM.isBeforeInTranslationUnit(Loc, R.getEnd());
196  }
197};
198
199} // anonymous namespace
200
201void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) {
202  ProtectedScopeFixer Fix(BodyCtx);
203}
204