1//===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===// 2// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6// 7//===----------------------------------------------------------------------===// 8// 9// Adds brackets in case statements that "contain" initialization of retaining 10// variable, thus emitting the "switch case is in protected scope" error. 11// 12//===----------------------------------------------------------------------===// 13 14#include "Transforms.h" 15#include "Internals.h" 16#include "clang/AST/ASTContext.h" 17#include "clang/Sema/SemaDiagnostic.h" 18 19using namespace clang; 20using namespace arcmt; 21using namespace trans; 22 23namespace { 24 25class LocalRefsCollector : public RecursiveASTVisitor<LocalRefsCollector> { 26 SmallVectorImpl<DeclRefExpr *> &Refs; 27 28public: 29 LocalRefsCollector(SmallVectorImpl<DeclRefExpr *> &refs) 30 : Refs(refs) { } 31 32 bool VisitDeclRefExpr(DeclRefExpr *E) { 33 if (ValueDecl *D = E->getDecl()) 34 if (D->getDeclContext()->getRedeclContext()->isFunctionOrMethod()) 35 Refs.push_back(E); 36 return true; 37 } 38}; 39 40struct CaseInfo { 41 SwitchCase *SC; 42 SourceRange Range; 43 enum { 44 St_Unchecked, 45 St_CannotFix, 46 St_Fixed 47 } State; 48 49 CaseInfo() : SC(nullptr), State(St_Unchecked) {} 50 CaseInfo(SwitchCase *S, SourceRange Range) 51 : SC(S), Range(Range), State(St_Unchecked) {} 52}; 53 54class CaseCollector : public RecursiveASTVisitor<CaseCollector> { 55 ParentMap &PMap; 56 SmallVectorImpl<CaseInfo> &Cases; 57 58public: 59 CaseCollector(ParentMap &PMap, SmallVectorImpl<CaseInfo> &Cases) 60 : PMap(PMap), Cases(Cases) { } 61 62 bool VisitSwitchStmt(SwitchStmt *S) { 63 SwitchCase *Curr = S->getSwitchCaseList(); 64 if (!Curr) 65 return true; 66 Stmt *Parent = getCaseParent(Curr); 67 Curr = Curr->getNextSwitchCase(); 68 // Make sure all case statements are in the same scope. 69 while (Curr) { 70 if (getCaseParent(Curr) != Parent) 71 return true; 72 Curr = Curr->getNextSwitchCase(); 73 } 74 75 SourceLocation NextLoc = S->getEndLoc(); 76 Curr = S->getSwitchCaseList(); 77 // We iterate over case statements in reverse source-order. 78 while (Curr) { 79 Cases.push_back( 80 CaseInfo(Curr, SourceRange(Curr->getBeginLoc(), NextLoc))); 81 NextLoc = Curr->getBeginLoc(); 82 Curr = Curr->getNextSwitchCase(); 83 } 84 return true; 85 } 86 87 Stmt *getCaseParent(SwitchCase *S) { 88 Stmt *Parent = PMap.getParent(S); 89 while (Parent && (isa<SwitchCase>(Parent) || isa<LabelStmt>(Parent))) 90 Parent = PMap.getParent(Parent); 91 return Parent; 92 } 93}; 94 95class ProtectedScopeFixer { 96 MigrationPass &Pass; 97 SourceManager &SM; 98 SmallVector<CaseInfo, 16> Cases; 99 SmallVector<DeclRefExpr *, 16> LocalRefs; 100 101public: 102 ProtectedScopeFixer(BodyContext &BodyCtx) 103 : Pass(BodyCtx.getMigrationContext().Pass), 104 SM(Pass.Ctx.getSourceManager()) { 105 106 CaseCollector(BodyCtx.getParentMap(), Cases) 107 .TraverseStmt(BodyCtx.getTopStmt()); 108 LocalRefsCollector(LocalRefs).TraverseStmt(BodyCtx.getTopStmt()); 109 110 SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange(); 111 const CapturedDiagList &DiagList = Pass.getDiags(); 112 // Copy the diagnostics so we don't have to worry about invaliding iterators 113 // from the diagnostic list. 114 SmallVector<StoredDiagnostic, 16> StoredDiags; 115 StoredDiags.append(DiagList.begin(), DiagList.end()); 116 SmallVectorImpl<StoredDiagnostic>::iterator 117 I = StoredDiags.begin(), E = StoredDiags.end(); 118 while (I != E) { 119 if (I->getID() == diag::err_switch_into_protected_scope && 120 isInRange(I->getLocation(), BodyRange)) { 121 handleProtectedScopeError(I, E); 122 continue; 123 } 124 ++I; 125 } 126 } 127 128 void handleProtectedScopeError( 129 SmallVectorImpl<StoredDiagnostic>::iterator &DiagI, 130 SmallVectorImpl<StoredDiagnostic>::iterator DiagE){ 131 Transaction Trans(Pass.TA); 132 assert(DiagI->getID() == diag::err_switch_into_protected_scope); 133 SourceLocation ErrLoc = DiagI->getLocation(); 134 bool handledAllNotes = true; 135 ++DiagI; 136 for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note; 137 ++DiagI) { 138 if (!handleProtectedNote(*DiagI)) 139 handledAllNotes = false; 140 } 141 142 if (handledAllNotes) 143 Pass.TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc); 144 } 145 146 bool handleProtectedNote(const StoredDiagnostic &Diag) { 147 assert(Diag.getLevel() == DiagnosticsEngine::Note); 148 149 for (unsigned i = 0; i != Cases.size(); i++) { 150 CaseInfo &info = Cases[i]; 151 if (isInRange(Diag.getLocation(), info.Range)) { 152 153 if (info.State == CaseInfo::St_Unchecked) 154 tryFixing(info); 155 assert(info.State != CaseInfo::St_Unchecked); 156 157 if (info.State == CaseInfo::St_Fixed) { 158 Pass.TA.clearDiagnostic(Diag.getID(), Diag.getLocation()); 159 return true; 160 } 161 return false; 162 } 163 } 164 165 return false; 166 } 167 168 void tryFixing(CaseInfo &info) { 169 assert(info.State == CaseInfo::St_Unchecked); 170 if (hasVarReferencedOutside(info)) { 171 info.State = CaseInfo::St_CannotFix; 172 return; 173 } 174 175 Pass.TA.insertAfterToken(info.SC->getColonLoc(), " {"); 176 Pass.TA.insert(info.Range.getEnd(), "}\n"); 177 info.State = CaseInfo::St_Fixed; 178 } 179 180 bool hasVarReferencedOutside(CaseInfo &info) { 181 for (unsigned i = 0, e = LocalRefs.size(); i != e; ++i) { 182 DeclRefExpr *DRE = LocalRefs[i]; 183 if (isInRange(DRE->getDecl()->getLocation(), info.Range) && 184 !isInRange(DRE->getLocation(), info.Range)) 185 return true; 186 } 187 return false; 188 } 189 190 bool isInRange(SourceLocation Loc, SourceRange R) { 191 if (Loc.isInvalid()) 192 return false; 193 return !SM.isBeforeInTranslationUnit(Loc, R.getBegin()) && 194 SM.isBeforeInTranslationUnit(Loc, R.getEnd()); 195 } 196}; 197 198} // anonymous namespace 199 200void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) { 201 ProtectedScopeFixer Fix(BodyCtx); 202} 203