1249261Sdim//===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===//
2249261Sdim//
3249261Sdim//                     The LLVM Compiler Infrastructure
4249261Sdim//
5249261Sdim// This file is distributed under the University of Illinois Open Source
6249261Sdim// License. See LICENSE.TXT for details.
7249261Sdim//
8249261Sdim//===----------------------------------------------------------------------===//
9249261Sdim//
// Adds braces around the bodies of case statements that "contain" the
// initialization of a retaining variable, which would otherwise trigger the
// "switch case is in protected scope" error.
12249261Sdim//
13249261Sdim//===----------------------------------------------------------------------===//
14249261Sdim
15249261Sdim#include "Transforms.h"
16249261Sdim#include "Internals.h"
17249261Sdim#include "clang/AST/ASTContext.h"
18249261Sdim#include "clang/Sema/SemaDiagnostic.h"
19249261Sdim
20249261Sdimusing namespace clang;
21249261Sdimusing namespace arcmt;
22249261Sdimusing namespace trans;
23249261Sdim
24249261Sdimnamespace {
25249261Sdim
26249261Sdimclass LocalRefsCollector : public RecursiveASTVisitor<LocalRefsCollector> {
27249261Sdim  SmallVectorImpl<DeclRefExpr *> &Refs;
28249261Sdim
29249261Sdimpublic:
30249261Sdim  LocalRefsCollector(SmallVectorImpl<DeclRefExpr *> &refs)
31249261Sdim    : Refs(refs) { }
32249261Sdim
33249261Sdim  bool VisitDeclRefExpr(DeclRefExpr *E) {
34249261Sdim    if (ValueDecl *D = E->getDecl())
35249261Sdim      if (D->getDeclContext()->getRedeclContext()->isFunctionOrMethod())
36249261Sdim        Refs.push_back(E);
37249261Sdim    return true;
38249261Sdim  }
39249261Sdim};
40249261Sdim
41249261Sdimstruct CaseInfo {
42249261Sdim  SwitchCase *SC;
43249261Sdim  SourceRange Range;
44249261Sdim  enum {
45249261Sdim    St_Unchecked,
46249261Sdim    St_CannotFix,
47249261Sdim    St_Fixed
48249261Sdim  } State;
49249261Sdim
50249261Sdim  CaseInfo() : SC(0), State(St_Unchecked) {}
51249261Sdim  CaseInfo(SwitchCase *S, SourceRange Range)
52249261Sdim    : SC(S), Range(Range), State(St_Unchecked) {}
53249261Sdim};
54249261Sdim
55249261Sdimclass CaseCollector : public RecursiveASTVisitor<CaseCollector> {
56249261Sdim  ParentMap &PMap;
57249261Sdim  SmallVectorImpl<CaseInfo> &Cases;
58249261Sdim
59249261Sdimpublic:
60249261Sdim  CaseCollector(ParentMap &PMap, SmallVectorImpl<CaseInfo> &Cases)
61249261Sdim    : PMap(PMap), Cases(Cases) { }
62249261Sdim
63249261Sdim  bool VisitSwitchStmt(SwitchStmt *S) {
64249261Sdim    SwitchCase *Curr = S->getSwitchCaseList();
65249261Sdim    if (!Curr)
66249261Sdim      return true;
67249261Sdim    Stmt *Parent = getCaseParent(Curr);
68249261Sdim    Curr = Curr->getNextSwitchCase();
69249261Sdim    // Make sure all case statements are in the same scope.
70249261Sdim    while (Curr) {
71249261Sdim      if (getCaseParent(Curr) != Parent)
72249261Sdim        return true;
73249261Sdim      Curr = Curr->getNextSwitchCase();
74249261Sdim    }
75249261Sdim
76249261Sdim    SourceLocation NextLoc = S->getLocEnd();
77249261Sdim    Curr = S->getSwitchCaseList();
78249261Sdim    // We iterate over case statements in reverse source-order.
79249261Sdim    while (Curr) {
80249261Sdim      Cases.push_back(CaseInfo(Curr,SourceRange(Curr->getLocStart(), NextLoc)));
81249261Sdim      NextLoc = Curr->getLocStart();
82249261Sdim      Curr = Curr->getNextSwitchCase();
83249261Sdim    }
84249261Sdim    return true;
85249261Sdim  }
86249261Sdim
87249261Sdim  Stmt *getCaseParent(SwitchCase *S) {
88249261Sdim    Stmt *Parent = PMap.getParent(S);
89249261Sdim    while (Parent && (isa<SwitchCase>(Parent) || isa<LabelStmt>(Parent)))
90249261Sdim      Parent = PMap.getParent(Parent);
91249261Sdim    return Parent;
92249261Sdim  }
93249261Sdim};
94249261Sdim
95249261Sdimclass ProtectedScopeFixer {
96249261Sdim  MigrationPass &Pass;
97249261Sdim  SourceManager &SM;
98249261Sdim  SmallVector<CaseInfo, 16> Cases;
99249261Sdim  SmallVector<DeclRefExpr *, 16> LocalRefs;
100249261Sdim
101249261Sdimpublic:
102249261Sdim  ProtectedScopeFixer(BodyContext &BodyCtx)
103249261Sdim    : Pass(BodyCtx.getMigrationContext().Pass),
104249261Sdim      SM(Pass.Ctx.getSourceManager()) {
105249261Sdim
106249261Sdim    CaseCollector(BodyCtx.getParentMap(), Cases)
107249261Sdim        .TraverseStmt(BodyCtx.getTopStmt());
108249261Sdim    LocalRefsCollector(LocalRefs).TraverseStmt(BodyCtx.getTopStmt());
109249261Sdim
110249261Sdim    SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange();
111249261Sdim    const CapturedDiagList &DiagList = Pass.getDiags();
112249261Sdim    // Copy the diagnostics so we don't have to worry about invaliding iterators
113249261Sdim    // from the diagnostic list.
114249261Sdim    SmallVector<StoredDiagnostic, 16> StoredDiags;
115249261Sdim    StoredDiags.append(DiagList.begin(), DiagList.end());
116249261Sdim    SmallVectorImpl<StoredDiagnostic>::iterator
117249261Sdim        I = StoredDiags.begin(), E = StoredDiags.end();
118249261Sdim    while (I != E) {
119249261Sdim      if (I->getID() == diag::err_switch_into_protected_scope &&
120249261Sdim          isInRange(I->getLocation(), BodyRange)) {
121249261Sdim        handleProtectedScopeError(I, E);
122249261Sdim        continue;
123249261Sdim      }
124249261Sdim      ++I;
125249261Sdim    }
126249261Sdim  }
127249261Sdim
128249261Sdim  void handleProtectedScopeError(
129249261Sdim                             SmallVectorImpl<StoredDiagnostic>::iterator &DiagI,
130249261Sdim                             SmallVectorImpl<StoredDiagnostic>::iterator DiagE){
131249261Sdim    Transaction Trans(Pass.TA);
132249261Sdim    assert(DiagI->getID() == diag::err_switch_into_protected_scope);
133249261Sdim    SourceLocation ErrLoc = DiagI->getLocation();
134249261Sdim    bool handledAllNotes = true;
135249261Sdim    ++DiagI;
136249261Sdim    for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note;
137249261Sdim         ++DiagI) {
138249261Sdim      if (!handleProtectedNote(*DiagI))
139249261Sdim        handledAllNotes = false;
140249261Sdim    }
141249261Sdim
142249261Sdim    if (handledAllNotes)
143249261Sdim      Pass.TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc);
144249261Sdim  }
145249261Sdim
146249261Sdim  bool handleProtectedNote(const StoredDiagnostic &Diag) {
147249261Sdim    assert(Diag.getLevel() == DiagnosticsEngine::Note);
148249261Sdim
149249261Sdim    for (unsigned i = 0; i != Cases.size(); i++) {
150249261Sdim      CaseInfo &info = Cases[i];
151249261Sdim      if (isInRange(Diag.getLocation(), info.Range)) {
152249261Sdim
153249261Sdim        if (info.State == CaseInfo::St_Unchecked)
154249261Sdim          tryFixing(info);
155249261Sdim        assert(info.State != CaseInfo::St_Unchecked);
156249261Sdim
157249261Sdim        if (info.State == CaseInfo::St_Fixed) {
158249261Sdim          Pass.TA.clearDiagnostic(Diag.getID(), Diag.getLocation());
159249261Sdim          return true;
160249261Sdim        }
161249261Sdim        return false;
162249261Sdim      }
163249261Sdim    }
164249261Sdim
165249261Sdim    return false;
166249261Sdim  }
167249261Sdim
168249261Sdim  void tryFixing(CaseInfo &info) {
169249261Sdim    assert(info.State == CaseInfo::St_Unchecked);
170249261Sdim    if (hasVarReferencedOutside(info)) {
171249261Sdim      info.State = CaseInfo::St_CannotFix;
172249261Sdim      return;
173249261Sdim    }
174249261Sdim
175249261Sdim    Pass.TA.insertAfterToken(info.SC->getColonLoc(), " {");
176249261Sdim    Pass.TA.insert(info.Range.getEnd(), "}\n");
177249261Sdim    info.State = CaseInfo::St_Fixed;
178249261Sdim  }
179249261Sdim
180249261Sdim  bool hasVarReferencedOutside(CaseInfo &info) {
181249261Sdim    for (unsigned i = 0, e = LocalRefs.size(); i != e; ++i) {
182249261Sdim      DeclRefExpr *DRE = LocalRefs[i];
183249261Sdim      if (isInRange(DRE->getDecl()->getLocation(), info.Range) &&
184249261Sdim          !isInRange(DRE->getLocation(), info.Range))
185249261Sdim        return true;
186249261Sdim    }
187249261Sdim    return false;
188249261Sdim  }
189249261Sdim
190249261Sdim  bool isInRange(SourceLocation Loc, SourceRange R) {
191249261Sdim    if (Loc.isInvalid())
192249261Sdim      return false;
193249261Sdim    return !SM.isBeforeInTranslationUnit(Loc, R.getBegin()) &&
194249261Sdim            SM.isBeforeInTranslationUnit(Loc, R.getEnd());
195249261Sdim  }
196249261Sdim};
197249261Sdim
198249261Sdim} // anonymous namespace
199249261Sdim
200249261Sdimvoid ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) {
201249261Sdim  ProtectedScopeFixer Fix(BodyCtx);
202249261Sdim}
203