1//===--- TransARCAssign.cpp - Transformations to ARC mode -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// makeAssignARCSafe:
10//
11// Add '__strong' where appropriate.
12//
13//  for (id x in collection) {
14//    x = 0;
15//  }
16// ---->
17//  for (__strong id x in collection) {
18//    x = 0;
19//  }
20//
21//===----------------------------------------------------------------------===//
22
23#include "Transforms.h"
24#include "Internals.h"
25#include "clang/AST/ASTContext.h"
26#include "clang/Sema/SemaDiagnostic.h"
27
28using namespace clang;
29using namespace arcmt;
30using namespace trans;
31
32namespace {
33
34class ARCAssignChecker : public RecursiveASTVisitor<ARCAssignChecker> {
35  MigrationPass &Pass;
36  llvm::DenseSet<VarDecl *> ModifiedVars;
37
38public:
39  ARCAssignChecker(MigrationPass &pass) : Pass(pass) { }
40
41  bool VisitBinaryOperator(BinaryOperator *Exp) {
42    if (Exp->getType()->isDependentType())
43      return true;
44
45    Expr *E = Exp->getLHS();
46    SourceLocation OrigLoc = E->getExprLoc();
47    SourceLocation Loc = OrigLoc;
48    DeclRefExpr *declRef = dyn_cast<DeclRefExpr>(E->IgnoreParenCasts());
49    if (declRef && isa<VarDecl>(declRef->getDecl())) {
50      ASTContext &Ctx = Pass.Ctx;
51      Expr::isModifiableLvalueResult IsLV = E->isModifiableLvalue(Ctx, &Loc);
52      if (IsLV != Expr::MLV_ConstQualified)
53        return true;
54      VarDecl *var = cast<VarDecl>(declRef->getDecl());
55      if (var->isARCPseudoStrong()) {
56        Transaction Trans(Pass.TA);
57        if (Pass.TA.clearDiagnostic(diag::err_typecheck_arr_assign_enumeration,
58                                    Exp->getOperatorLoc())) {
59          if (!ModifiedVars.count(var)) {
60            TypeLoc TLoc = var->getTypeSourceInfo()->getTypeLoc();
61            Pass.TA.insert(TLoc.getBeginLoc(), "__strong ");
62            ModifiedVars.insert(var);
63          }
64        }
65      }
66    }
67
68    return true;
69  }
70};
71
72} // anonymous namespace
73
74void trans::makeAssignARCSafe(MigrationPass &pass) {
75  ARCAssignChecker assignCheck(pass);
76  assignCheck.TraverseDecl(pass.Ctx.getTranslationUnitDecl());
77}
78