1326941Sdim//===--- SourceExtraction.cpp - Clang refactoring library -----------------===//
2326941Sdim//
3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4353358Sdim// See https://llvm.org/LICENSE.txt for license information.
5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6326941Sdim//
7326941Sdim//===----------------------------------------------------------------------===//
8326941Sdim
9360784Sdim#include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
10326941Sdim#include "clang/AST/Stmt.h"
11326941Sdim#include "clang/AST/StmtCXX.h"
12326941Sdim#include "clang/AST/StmtObjC.h"
13326941Sdim#include "clang/Basic/SourceManager.h"
14326941Sdim#include "clang/Lex/Lexer.h"
15326941Sdim
16326941Sdimusing namespace clang;
17326941Sdim
18326941Sdimnamespace {
19326941Sdim
20326941Sdim/// Returns true if the token at the given location is a semicolon.
21326941Sdimbool isSemicolonAtLocation(SourceLocation TokenLoc, const SourceManager &SM,
22326941Sdim                           const LangOptions &LangOpts) {
23326941Sdim  return Lexer::getSourceText(
24326941Sdim             CharSourceRange::getTokenRange(TokenLoc, TokenLoc), SM,
25326941Sdim             LangOpts) == ";";
26326941Sdim}
27326941Sdim
28326941Sdim/// Returns true if there should be a semicolon after the given statement.
29326941Sdimbool isSemicolonRequiredAfter(const Stmt *S) {
30326941Sdim  if (isa<CompoundStmt>(S))
31326941Sdim    return false;
32326941Sdim  if (const auto *If = dyn_cast<IfStmt>(S))
33326941Sdim    return isSemicolonRequiredAfter(If->getElse() ? If->getElse()
34326941Sdim                                                  : If->getThen());
35326941Sdim  if (const auto *While = dyn_cast<WhileStmt>(S))
36326941Sdim    return isSemicolonRequiredAfter(While->getBody());
37326941Sdim  if (const auto *For = dyn_cast<ForStmt>(S))
38326941Sdim    return isSemicolonRequiredAfter(For->getBody());
39326941Sdim  if (const auto *CXXFor = dyn_cast<CXXForRangeStmt>(S))
40326941Sdim    return isSemicolonRequiredAfter(CXXFor->getBody());
41326941Sdim  if (const auto *ObjCFor = dyn_cast<ObjCForCollectionStmt>(S))
42326941Sdim    return isSemicolonRequiredAfter(ObjCFor->getBody());
43360784Sdim  if(const auto *Switch = dyn_cast<SwitchStmt>(S))
44360784Sdim    return isSemicolonRequiredAfter(Switch->getBody());
45360784Sdim  if(const auto *Case = dyn_cast<SwitchCase>(S))
46360784Sdim    return isSemicolonRequiredAfter(Case->getSubStmt());
47326941Sdim  switch (S->getStmtClass()) {
48360784Sdim  case Stmt::DeclStmtClass:
49326941Sdim  case Stmt::CXXTryStmtClass:
50326941Sdim  case Stmt::ObjCAtSynchronizedStmtClass:
51326941Sdim  case Stmt::ObjCAutoreleasePoolStmtClass:
52326941Sdim  case Stmt::ObjCAtTryStmtClass:
53326941Sdim    return false;
54326941Sdim  default:
55326941Sdim    return true;
56326941Sdim  }
57326941Sdim}
58326941Sdim
59326941Sdim/// Returns true if the two source locations are on the same line.
60326941Sdimbool areOnSameLine(SourceLocation Loc1, SourceLocation Loc2,
61326941Sdim                   const SourceManager &SM) {
62326941Sdim  return !Loc1.isMacroID() && !Loc2.isMacroID() &&
63326941Sdim         SM.getSpellingLineNumber(Loc1) == SM.getSpellingLineNumber(Loc2);
64326941Sdim}
65326941Sdim
66326941Sdim} // end anonymous namespace
67326941Sdim
68326941Sdimnamespace clang {
69326941Sdimnamespace tooling {
70326941Sdim
71326941SdimExtractionSemicolonPolicy
72326941SdimExtractionSemicolonPolicy::compute(const Stmt *S, SourceRange &ExtractedRange,
73326941Sdim                                   const SourceManager &SM,
74326941Sdim                                   const LangOptions &LangOpts) {
75326941Sdim  auto neededInExtractedFunction = []() {
76326941Sdim    return ExtractionSemicolonPolicy(true, false);
77326941Sdim  };
78326941Sdim  auto neededInOriginalFunction = []() {
79326941Sdim    return ExtractionSemicolonPolicy(false, true);
80326941Sdim  };
81326941Sdim
82326941Sdim  /// The extracted expression should be terminated with a ';'. The call to
83326941Sdim  /// the extracted function will replace this expression, so it won't need
84326941Sdim  /// a terminating ';'.
85326941Sdim  if (isa<Expr>(S))
86326941Sdim    return neededInExtractedFunction();
87326941Sdim
88326941Sdim  /// Some statements don't need to be terminated with ';'. The call to the
89326941Sdim  /// extracted function will be a standalone statement, so it should be
90326941Sdim  /// terminated with a ';'.
91326941Sdim  bool NeedsSemi = isSemicolonRequiredAfter(S);
92326941Sdim  if (!NeedsSemi)
93326941Sdim    return neededInOriginalFunction();
94326941Sdim
95326941Sdim  /// Some statements might end at ';'. The extraction will move that ';', so
96326941Sdim  /// the call to the extracted function should be terminated with a ';'.
97326941Sdim  SourceLocation End = ExtractedRange.getEnd();
98326941Sdim  if (isSemicolonAtLocation(End, SM, LangOpts))
99326941Sdim    return neededInOriginalFunction();
100326941Sdim
101326941Sdim  /// Other statements should generally have a trailing ';'. We can try to find
102326941Sdim  /// it and move it together it with the extracted code.
103326941Sdim  Optional<Token> NextToken = Lexer::findNextToken(End, SM, LangOpts);
104326941Sdim  if (NextToken && NextToken->is(tok::semi) &&
105326941Sdim      areOnSameLine(NextToken->getLocation(), End, SM)) {
106326941Sdim    ExtractedRange.setEnd(NextToken->getLocation());
107326941Sdim    return neededInOriginalFunction();
108326941Sdim  }
109326941Sdim
110326941Sdim  /// Otherwise insert semicolons in both places.
111326941Sdim  return ExtractionSemicolonPolicy(true, true);
112326941Sdim}
113326941Sdim
114326941Sdim} // end namespace tooling
115326941Sdim} // end namespace clang
116