1//===--- Stencil.cpp - Stencil implementation -------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "clang/Tooling/Transformer/Stencil.h"
10#include "clang/AST/ASTContext.h"
11#include "clang/AST/ASTTypeTraits.h"
12#include "clang/AST/Expr.h"
13#include "clang/ASTMatchers/ASTMatchFinder.h"
14#include "clang/Basic/SourceLocation.h"
15#include "clang/Lex/Lexer.h"
16#include "clang/Tooling/Transformer/SourceCode.h"
17#include "clang/Tooling/Transformer/SourceCodeBuilders.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/ADT/Twine.h"
20#include "llvm/Support/Errc.h"
21#include "llvm/Support/Error.h"
22#include <atomic>
23#include <memory>
24#include <string>
25
26using namespace clang;
27using namespace transformer;
28
29using ast_matchers::BoundNodes;
30using ast_matchers::MatchFinder;
31using llvm::errc;
32using llvm::Error;
33using llvm::Expected;
34using llvm::StringError;
35
36static llvm::Expected<DynTypedNode> getNode(const BoundNodes &Nodes,
37                                            StringRef Id) {
38  auto &NodesMap = Nodes.getMap();
39  auto It = NodesMap.find(Id);
40  if (It == NodesMap.end())
41    return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
42                                               "Id not bound: " + Id);
43  return It->second;
44}
45
46static Error printNode(StringRef Id, const MatchFinder::MatchResult &Match,
47                       std::string *Result) {
48  std::string Output;
49  llvm::raw_string_ostream Os(Output);
50  auto NodeOrErr = getNode(Match.Nodes, Id);
51  if (auto Err = NodeOrErr.takeError())
52    return Err;
53  NodeOrErr->print(Os, PrintingPolicy(Match.Context->getLangOpts()));
54  *Result += Os.str();
55  return Error::success();
56}
57
58namespace {
59// An arbitrary fragment of code within a stencil.
60class RawTextStencil : public StencilInterface {
61  std::string Text;
62
63public:
64  explicit RawTextStencil(std::string T) : Text(std::move(T)) {}
65
66  std::string toString() const override {
67    std::string Result;
68    llvm::raw_string_ostream OS(Result);
69    OS << "\"";
70    OS.write_escaped(Text);
71    OS << "\"";
72    OS.flush();
73    return Result;
74  }
75
76  Error eval(const MatchFinder::MatchResult &Match,
77             std::string *Result) const override {
78    Result->append(Text);
79    return Error::success();
80  }
81};
82
83// A debugging operation to dump the AST for a particular (bound) AST node.
84class DebugPrintNodeStencil : public StencilInterface {
85  std::string Id;
86
87public:
88  explicit DebugPrintNodeStencil(std::string S) : Id(std::move(S)) {}
89
90  std::string toString() const override {
91    return (llvm::Twine("dPrint(\"") + Id + "\")").str();
92  }
93
94  Error eval(const MatchFinder::MatchResult &Match,
95             std::string *Result) const override {
96    return printNode(Id, Match, Result);
97  }
98};
99
// The operations understood by UnaryOperationStencil; each takes exactly one
// bound node id as its argument.
enum class UnaryNodeOperator {
  Parens,         // `expression(id)`: tooling::buildParens on the Expr.
  Deref,          // `deref(id)`: tooling::buildDereference on the Expr.
  MaybeDeref,     // `maybeDeref(id)`: dereference only pointer-like Exprs.
  AddressOf,      // `addressOf(id)`: tooling::buildAddressOf on the Expr.
  MaybeAddressOf, // `maybeAddressOf(id)`: take address of non-pointer Exprs.
  Describe        // `describe(id)`: print any bound node, not only Exprs.
};
109
110// Generic container for stencil operations with a (single) node-id argument.
111class UnaryOperationStencil : public StencilInterface {
112  UnaryNodeOperator Op;
113  std::string Id;
114
115public:
116  UnaryOperationStencil(UnaryNodeOperator Op, std::string Id)
117      : Op(Op), Id(std::move(Id)) {}
118
119  std::string toString() const override {
120    StringRef OpName;
121    switch (Op) {
122    case UnaryNodeOperator::Parens:
123      OpName = "expression";
124      break;
125    case UnaryNodeOperator::Deref:
126      OpName = "deref";
127      break;
128    case UnaryNodeOperator::MaybeDeref:
129      OpName = "maybeDeref";
130      break;
131    case UnaryNodeOperator::AddressOf:
132      OpName = "addressOf";
133      break;
134    case UnaryNodeOperator::MaybeAddressOf:
135      OpName = "maybeAddressOf";
136      break;
137    case UnaryNodeOperator::Describe:
138      OpName = "describe";
139      break;
140    }
141    return (OpName + "(\"" + Id + "\")").str();
142  }
143
144  Error eval(const MatchFinder::MatchResult &Match,
145             std::string *Result) const override {
146    // The `Describe` operation can be applied to any node, not just
147    // expressions, so it is handled here, separately.
148    if (Op == UnaryNodeOperator::Describe)
149      return printNode(Id, Match, Result);
150
151    const auto *E = Match.Nodes.getNodeAs<Expr>(Id);
152    if (E == nullptr)
153      return llvm::make_error<StringError>(errc::invalid_argument,
154                                           "Id not bound or not Expr: " + Id);
155    std::optional<std::string> Source;
156    switch (Op) {
157    case UnaryNodeOperator::Parens:
158      Source = tooling::buildParens(*E, *Match.Context);
159      break;
160    case UnaryNodeOperator::Deref:
161      Source = tooling::buildDereference(*E, *Match.Context);
162      break;
163    case UnaryNodeOperator::MaybeDeref:
164      if (E->getType()->isAnyPointerType() ||
165          tooling::isKnownPointerLikeType(E->getType(), *Match.Context)) {
166        // Strip off any operator->. This can only occur inside an actual arrow
167        // member access, so we treat it as equivalent to an actual object
168        // expression.
169        if (const auto *OpCall = dyn_cast<clang::CXXOperatorCallExpr>(E)) {
170          if (OpCall->getOperator() == clang::OO_Arrow &&
171              OpCall->getNumArgs() == 1) {
172            E = OpCall->getArg(0);
173          }
174        }
175        Source = tooling::buildDereference(*E, *Match.Context);
176        break;
177      }
178      *Result += tooling::getText(*E, *Match.Context);
179      return Error::success();
180    case UnaryNodeOperator::AddressOf:
181      Source = tooling::buildAddressOf(*E, *Match.Context);
182      break;
183    case UnaryNodeOperator::MaybeAddressOf:
184      if (E->getType()->isAnyPointerType() ||
185          tooling::isKnownPointerLikeType(E->getType(), *Match.Context)) {
186        // Strip off any operator->. This can only occur inside an actual arrow
187        // member access, so we treat it as equivalent to an actual object
188        // expression.
189        if (const auto *OpCall = dyn_cast<clang::CXXOperatorCallExpr>(E)) {
190          if (OpCall->getOperator() == clang::OO_Arrow &&
191              OpCall->getNumArgs() == 1) {
192            E = OpCall->getArg(0);
193          }
194        }
195        *Result += tooling::getText(*E, *Match.Context);
196        return Error::success();
197      }
198      Source = tooling::buildAddressOf(*E, *Match.Context);
199      break;
200    case UnaryNodeOperator::Describe:
201      llvm_unreachable("This case is handled at the start of the function");
202    }
203    if (!Source)
204      return llvm::make_error<StringError>(
205          errc::invalid_argument,
206          "Could not construct expression source from ID: " + Id);
207    *Result += *Source;
208    return Error::success();
209  }
210};
211
212// The fragment of code corresponding to the selected range.
213class SelectorStencil : public StencilInterface {
214  RangeSelector Selector;
215
216public:
217  explicit SelectorStencil(RangeSelector S) : Selector(std::move(S)) {}
218
219  std::string toString() const override { return "selection(...)"; }
220
221  Error eval(const MatchFinder::MatchResult &Match,
222             std::string *Result) const override {
223    auto RawRange = Selector(Match);
224    if (!RawRange)
225      return RawRange.takeError();
226    CharSourceRange Range = Lexer::makeFileCharRange(
227        *RawRange, *Match.SourceManager, Match.Context->getLangOpts());
228    if (Range.isInvalid()) {
229      // Validate the original range to attempt to get a meaningful error
230      // message. If it's valid, then something else is the cause and we just
231      // return the generic failure message.
232      if (auto Err =
233              tooling::validateEditRange(*RawRange, *Match.SourceManager))
234        return handleErrors(std::move(Err), [](std::unique_ptr<StringError> E) {
235          assert(E->convertToErrorCode() ==
236                     llvm::make_error_code(errc::invalid_argument) &&
237                 "Validation errors must carry the invalid_argument code");
238          return llvm::createStringError(
239              errc::invalid_argument,
240              "selected range could not be resolved to a valid source range; " +
241                  E->getMessage());
242        });
243      return llvm::createStringError(
244          errc::invalid_argument,
245          "selected range could not be resolved to a valid source range");
246    }
247    // Validate `Range`, because `makeFileCharRange` accepts some ranges that
248    // `validateEditRange` rejects.
249    if (auto Err = tooling::validateEditRange(Range, *Match.SourceManager))
250      return joinErrors(
251          llvm::createStringError(errc::invalid_argument,
252                                  "selected range is not valid for editing"),
253          std::move(Err));
254    *Result += tooling::getText(Range, *Match.Context);
255    return Error::success();
256  }
257};
258
259// A stencil operation to build a member access `e.m` or `e->m`, as appropriate.
260class AccessStencil : public StencilInterface {
261  std::string BaseId;
262  Stencil Member;
263
264public:
265  AccessStencil(StringRef BaseId, Stencil Member)
266      : BaseId(std::string(BaseId)), Member(std::move(Member)) {}
267
268  std::string toString() const override {
269    return (llvm::Twine("access(\"") + BaseId + "\", " + Member->toString() +
270            ")")
271        .str();
272  }
273
274  Error eval(const MatchFinder::MatchResult &Match,
275             std::string *Result) const override {
276    const auto *E = Match.Nodes.getNodeAs<Expr>(BaseId);
277    if (E == nullptr)
278      return llvm::make_error<StringError>(errc::invalid_argument,
279                                           "Id not bound: " + BaseId);
280    std::optional<std::string> S = tooling::buildAccess(*E, *Match.Context);
281    if (!S)
282      return llvm::make_error<StringError>(
283          errc::invalid_argument,
284          "Could not construct object text from ID: " + BaseId);
285    *Result += *S;
286    return Member->eval(Match, Result);
287  }
288};
289
290class IfBoundStencil : public StencilInterface {
291  std::string Id;
292  Stencil TrueStencil;
293  Stencil FalseStencil;
294
295public:
296  IfBoundStencil(StringRef Id, Stencil TrueStencil, Stencil FalseStencil)
297      : Id(std::string(Id)), TrueStencil(std::move(TrueStencil)),
298        FalseStencil(std::move(FalseStencil)) {}
299
300  std::string toString() const override {
301    return (llvm::Twine("ifBound(\"") + Id + "\", " + TrueStencil->toString() +
302            ", " + FalseStencil->toString() + ")")
303        .str();
304  }
305
306  Error eval(const MatchFinder::MatchResult &Match,
307             std::string *Result) const override {
308    auto &M = Match.Nodes.getMap();
309    return (M.find(Id) != M.end() ? TrueStencil : FalseStencil)
310        ->eval(Match, Result);
311  }
312};
313
314class SelectBoundStencil : public clang::transformer::StencilInterface {
315  static bool containsNoNullStencils(
316      const std::vector<std::pair<std::string, Stencil>> &Cases) {
317    for (const auto &S : Cases)
318      if (S.second == nullptr)
319        return false;
320    return true;
321  }
322
323public:
324  SelectBoundStencil(std::vector<std::pair<std::string, Stencil>> Cases,
325                     Stencil Default)
326      : CaseStencils(std::move(Cases)), DefaultStencil(std::move(Default)) {
327    assert(containsNoNullStencils(CaseStencils) &&
328           "cases of selectBound may not be null");
329  }
330  ~SelectBoundStencil() override{};
331
332  llvm::Error eval(const MatchFinder::MatchResult &match,
333                   std::string *result) const override {
334    const BoundNodes::IDToNodeMap &NodeMap = match.Nodes.getMap();
335    for (const auto &S : CaseStencils) {
336      if (NodeMap.count(S.first) > 0) {
337        return S.second->eval(match, result);
338      }
339    }
340
341    if (DefaultStencil != nullptr) {
342      return DefaultStencil->eval(match, result);
343    }
344
345    llvm::SmallVector<llvm::StringRef, 2> CaseIDs;
346    CaseIDs.reserve(CaseStencils.size());
347    for (const auto &S : CaseStencils)
348      CaseIDs.emplace_back(S.first);
349
350    return llvm::createStringError(
351        errc::result_out_of_range,
352        llvm::Twine("selectBound failed: no cases bound and no default: {") +
353            llvm::join(CaseIDs, ", ") + "}");
354  }
355
356  std::string toString() const override {
357    std::string Buffer;
358    llvm::raw_string_ostream Stream(Buffer);
359    Stream << "selectBound({";
360    bool First = true;
361    for (const auto &S : CaseStencils) {
362      if (First)
363        First = false;
364      else
365        Stream << "}, ";
366      Stream << "{\"" << S.first << "\", " << S.second->toString();
367    }
368    Stream << "}}";
369    if (DefaultStencil != nullptr) {
370      Stream << ", " << DefaultStencil->toString();
371    }
372    Stream << ")";
373    return Stream.str();
374  }
375
376private:
377  std::vector<std::pair<std::string, Stencil>> CaseStencils;
378  Stencil DefaultStencil;
379};
380
381class SequenceStencil : public StencilInterface {
382  std::vector<Stencil> Stencils;
383
384public:
385  SequenceStencil(std::vector<Stencil> Stencils)
386      : Stencils(std::move(Stencils)) {}
387
388  std::string toString() const override {
389    llvm::SmallVector<std::string, 2> Parts;
390    Parts.reserve(Stencils.size());
391    for (const auto &S : Stencils)
392      Parts.push_back(S->toString());
393    return (llvm::Twine("seq(") + llvm::join(Parts, ", ") + ")").str();
394  }
395
396  Error eval(const MatchFinder::MatchResult &Match,
397             std::string *Result) const override {
398    for (const auto &S : Stencils)
399      if (auto Err = S->eval(Match, Result))
400        return Err;
401    return Error::success();
402  }
403};
404
405class RunStencil : public StencilInterface {
406  MatchConsumer<std::string> Consumer;
407
408public:
409  explicit RunStencil(MatchConsumer<std::string> C) : Consumer(std::move(C)) {}
410
411  std::string toString() const override { return "run(...)"; }
412
413  Error eval(const MatchFinder::MatchResult &Match,
414             std::string *Result) const override {
415
416    Expected<std::string> Value = Consumer(Match);
417    if (!Value)
418      return Value.takeError();
419    *Result += *Value;
420    return Error::success();
421  }
422};
423} // namespace
424
425Stencil transformer::detail::makeStencil(StringRef Text) {
426  return std::make_shared<RawTextStencil>(std::string(Text));
427}
428
429Stencil transformer::detail::makeStencil(RangeSelector Selector) {
430  return std::make_shared<SelectorStencil>(std::move(Selector));
431}
432
433Stencil transformer::dPrint(StringRef Id) {
434  return std::make_shared<DebugPrintNodeStencil>(std::string(Id));
435}
436
437Stencil transformer::expression(llvm::StringRef Id) {
438  return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::Parens,
439                                                 std::string(Id));
440}
441
442Stencil transformer::deref(llvm::StringRef ExprId) {
443  return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::Deref,
444                                                 std::string(ExprId));
445}
446
447Stencil transformer::maybeDeref(llvm::StringRef ExprId) {
448  return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::MaybeDeref,
449                                                 std::string(ExprId));
450}
451
452Stencil transformer::addressOf(llvm::StringRef ExprId) {
453  return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::AddressOf,
454                                                 std::string(ExprId));
455}
456
457Stencil transformer::maybeAddressOf(llvm::StringRef ExprId) {
458  return std::make_shared<UnaryOperationStencil>(
459      UnaryNodeOperator::MaybeAddressOf, std::string(ExprId));
460}
461
462Stencil transformer::describe(StringRef Id) {
463  return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::Describe,
464                                                 std::string(Id));
465}
466
467Stencil transformer::access(StringRef BaseId, Stencil Member) {
468  return std::make_shared<AccessStencil>(BaseId, std::move(Member));
469}
470
471Stencil transformer::ifBound(StringRef Id, Stencil TrueStencil,
472                             Stencil FalseStencil) {
473  return std::make_shared<IfBoundStencil>(Id, std::move(TrueStencil),
474                                          std::move(FalseStencil));
475}
476
477Stencil transformer::selectBound(
478    std::vector<std::pair<std::string, Stencil>> CaseStencils,
479    Stencil DefaultStencil) {
480  return std::make_shared<SelectBoundStencil>(std::move(CaseStencils),
481                                              std::move(DefaultStencil));
482}
483
484Stencil transformer::run(MatchConsumer<std::string> Fn) {
485  return std::make_shared<RunStencil>(std::move(Fn));
486}
487
488Stencil transformer::catVector(std::vector<Stencil> Parts) {
489  // Only one argument, so don't wrap in sequence.
490  if (Parts.size() == 1)
491    return std::move(Parts[0]);
492  return std::make_shared<SequenceStencil>(std::move(Parts));
493}
494