1//===- SymbolRewriter.cpp - Symbol Rewriter -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// SymbolRewriter is an LLVM pass which can rewrite symbols transparently within
10// existing code.  It is implemented as a compiler pass and is configured via a
11// YAML configuration file.
12//
13// The YAML configuration file format is as follows:
14//
15// RewriteMapFile := RewriteDescriptors
16// RewriteDescriptors := RewriteDescriptor | RewriteDescriptors
17// RewriteDescriptor := RewriteDescriptorType ':' '{' RewriteDescriptorFields '}'
18// RewriteDescriptorFields := RewriteDescriptorField | RewriteDescriptorFields
19// RewriteDescriptorField := FieldIdentifier ':' FieldValue ','
20// RewriteDescriptorType := Identifier
21// FieldIdentifier := Identifier
22// FieldValue := Identifier
23// Identifier := [0-9a-zA-Z]+
24//
25// Currently, the following descriptor types are supported:
26//
27// - function:          (function rewriting)
28//      + Source        (original name of the function)
29//      + Target        (explicit transformation)
30//      + Transform     (pattern transformation)
31//      + Naked         (boolean, whether the function is undecorated)
32// - global variable:   (external linkage global variable rewriting)
33//      + Source        (original name of externally visible variable)
34//      + Target        (explicit transformation)
35//      + Transform     (pattern transformation)
36// - global alias:      (global alias rewriting)
37//      + Source        (original name of the aliased name)
38//      + Target        (explicit transformation)
39//      + Transform     (pattern transformation)
40//
41// Note that source and exactly one of [Target, Transform] must be provided
42//
// New rewrite descriptors can be created.  Adding a new rewrite descriptor
44// involves:
45//
//  a) extending the rewrite descriptor kind enumeration
//     (SymbolRewriter::RewriteDescriptor::Type)
48//  b) implementing the new descriptor
49//     (c.f. <anonymous>::ExplicitRewriteFunctionDescriptor)
50//  c) extending the rewrite map parser
51//     (<anonymous>::RewriteMapParser::parseEntry)
52//
//  To rewrite symbols, enable the pass with the `-rewrite-symbols` option and
//  specify the map file to use for the rewriting via the `-rewrite-map-file`
//  option.
56//
57//===----------------------------------------------------------------------===//
58
59#include "llvm/Transforms/Utils/SymbolRewriter.h"
60#include "llvm/ADT/STLExtras.h"
61#include "llvm/ADT/SmallString.h"
62#include "llvm/ADT/StringRef.h"
63#include "llvm/ADT/ilist.h"
64#include "llvm/ADT/iterator_range.h"
65#include "llvm/IR/Comdat.h"
66#include "llvm/IR/Function.h"
67#include "llvm/IR/GlobalAlias.h"
68#include "llvm/IR/GlobalObject.h"
69#include "llvm/IR/GlobalVariable.h"
70#include "llvm/IR/Module.h"
71#include "llvm/IR/Value.h"
72#include "llvm/InitializePasses.h"
73#include "llvm/Pass.h"
74#include "llvm/Support/Casting.h"
75#include "llvm/Support/CommandLine.h"
76#include "llvm/Support/ErrorHandling.h"
77#include "llvm/Support/ErrorOr.h"
78#include "llvm/Support/MemoryBuffer.h"
79#include "llvm/Support/Regex.h"
80#include "llvm/Support/SourceMgr.h"
81#include "llvm/Support/YAMLParser.h"
82#include <memory>
83#include <string>
84#include <vector>
85
86using namespace llvm;
87using namespace SymbolRewriter;
88
89#define DEBUG_TYPE "symbol-rewriter"
90
91static cl::list<std::string> RewriteMapFiles("rewrite-map-file",
92                                             cl::desc("Symbol Rewrite Map"),
93                                             cl::value_desc("filename"),
94                                             cl::Hidden);
95
96static void rewriteComdat(Module &M, GlobalObject *GO,
97                          const std::string &Source,
98                          const std::string &Target) {
99  if (Comdat *CD = GO->getComdat()) {
100    auto &Comdats = M.getComdatSymbolTable();
101
102    Comdat *C = M.getOrInsertComdat(Target);
103    C->setSelectionKind(CD->getSelectionKind());
104    GO->setComdat(C);
105
106    Comdats.erase(Comdats.find(Source));
107  }
108}
109
110namespace {
111
112template <RewriteDescriptor::Type DT, typename ValueType,
113          ValueType *(Module::*Get)(StringRef) const>
114class ExplicitRewriteDescriptor : public RewriteDescriptor {
115public:
116  const std::string Source;
117  const std::string Target;
118
119  ExplicitRewriteDescriptor(StringRef S, StringRef T, const bool Naked)
120      : RewriteDescriptor(DT), Source(Naked ? StringRef("\01" + S.str()) : S),
121        Target(T) {}
122
123  bool performOnModule(Module &M) override;
124
125  static bool classof(const RewriteDescriptor *RD) {
126    return RD->getType() == DT;
127  }
128};
129
130} // end anonymous namespace
131
132template <RewriteDescriptor::Type DT, typename ValueType,
133          ValueType *(Module::*Get)(StringRef) const>
134bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) {
135  bool Changed = false;
136  if (ValueType *S = (M.*Get)(Source)) {
137    if (GlobalObject *GO = dyn_cast<GlobalObject>(S))
138      rewriteComdat(M, GO, Source, Target);
139
140    if (Value *T = (M.*Get)(Target))
141      S->setValueName(T->getValueName());
142    else
143      S->setName(Target);
144
145    Changed = true;
146  }
147  return Changed;
148}
149
150namespace {
151
152template <RewriteDescriptor::Type DT, typename ValueType,
153          ValueType *(Module::*Get)(StringRef) const,
154          iterator_range<typename iplist<ValueType>::iterator>
155          (Module::*Iterator)()>
156class PatternRewriteDescriptor : public RewriteDescriptor {
157public:
158  const std::string Pattern;
159  const std::string Transform;
160
161  PatternRewriteDescriptor(StringRef P, StringRef T)
162    : RewriteDescriptor(DT), Pattern(P), Transform(T) { }
163
164  bool performOnModule(Module &M) override;
165
166  static bool classof(const RewriteDescriptor *RD) {
167    return RD->getType() == DT;
168  }
169};
170
171} // end anonymous namespace
172
173template <RewriteDescriptor::Type DT, typename ValueType,
174          ValueType *(Module::*Get)(StringRef) const,
175          iterator_range<typename iplist<ValueType>::iterator>
176          (Module::*Iterator)()>
177bool PatternRewriteDescriptor<DT, ValueType, Get, Iterator>::
178performOnModule(Module &M) {
179  bool Changed = false;
180  for (auto &C : (M.*Iterator)()) {
181    std::string Error;
182
183    std::string Name = Regex(Pattern).sub(Transform, C.getName(), &Error);
184    if (!Error.empty())
185      report_fatal_error("unable to transforn " + C.getName() + " in " +
186                         M.getModuleIdentifier() + ": " + Error);
187
188    if (C.getName() == Name)
189      continue;
190
191    if (GlobalObject *GO = dyn_cast<GlobalObject>(&C))
192      rewriteComdat(M, GO, C.getName(), Name);
193
194    if (Value *V = (M.*Get)(Name))
195      C.setValueName(V->getValueName());
196    else
197      C.setName(Name);
198
199    Changed = true;
200  }
201  return Changed;
202}
203
204namespace {
205
206/// Represents a rewrite for an explicitly named (function) symbol.  Both the
207/// source function name and target function name of the transformation are
208/// explicitly spelt out.
209using ExplicitRewriteFunctionDescriptor =
210    ExplicitRewriteDescriptor<RewriteDescriptor::Type::Function, Function,
211                              &Module::getFunction>;
212
213/// Represents a rewrite for an explicitly named (global variable) symbol.  Both
214/// the source variable name and target variable name are spelt out.  This
215/// applies only to module level variables.
216using ExplicitRewriteGlobalVariableDescriptor =
217    ExplicitRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
218                              GlobalVariable, &Module::getGlobalVariable>;
219
220/// Represents a rewrite for an explicitly named global alias.  Both the source
221/// and target name are explicitly spelt out.
222using ExplicitRewriteNamedAliasDescriptor =
223    ExplicitRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, GlobalAlias,
224                              &Module::getNamedAlias>;
225
226/// Represents a rewrite for a regular expression based pattern for functions.
227/// A pattern for the function name is provided and a transformation for that
228/// pattern to determine the target function name create the rewrite rule.
229using PatternRewriteFunctionDescriptor =
230    PatternRewriteDescriptor<RewriteDescriptor::Type::Function, Function,
231                             &Module::getFunction, &Module::functions>;
232
233/// Represents a rewrite for a global variable based upon a matching pattern.
234/// Each global variable matching the provided pattern will be transformed as
235/// described in the transformation pattern for the target.  Applies only to
236/// module level variables.
237using PatternRewriteGlobalVariableDescriptor =
238    PatternRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
239                             GlobalVariable, &Module::getGlobalVariable,
240                             &Module::globals>;
241
242/// PatternRewriteNamedAliasDescriptor - represents a rewrite for global
243/// aliases which match a given pattern.  The provided transformation will be
244/// applied to each of the matching names.
245using PatternRewriteNamedAliasDescriptor =
246    PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, GlobalAlias,
247                             &Module::getNamedAlias, &Module::aliases>;
248
249} // end anonymous namespace
250
251bool RewriteMapParser::parse(const std::string &MapFile,
252                             RewriteDescriptorList *DL) {
253  ErrorOr<std::unique_ptr<MemoryBuffer>> Mapping =
254      MemoryBuffer::getFile(MapFile);
255
256  if (!Mapping)
257    report_fatal_error("unable to read rewrite map '" + MapFile + "': " +
258                       Mapping.getError().message());
259
260  if (!parse(*Mapping, DL))
261    report_fatal_error("unable to parse rewrite map '" + MapFile + "'");
262
263  return true;
264}
265
266bool RewriteMapParser::parse(std::unique_ptr<MemoryBuffer> &MapFile,
267                             RewriteDescriptorList *DL) {
268  SourceMgr SM;
269  yaml::Stream YS(MapFile->getBuffer(), SM);
270
271  for (auto &Document : YS) {
272    yaml::MappingNode *DescriptorList;
273
274    // ignore empty documents
275    if (isa<yaml::NullNode>(Document.getRoot()))
276      continue;
277
278    DescriptorList = dyn_cast<yaml::MappingNode>(Document.getRoot());
279    if (!DescriptorList) {
280      YS.printError(Document.getRoot(), "DescriptorList node must be a map");
281      return false;
282    }
283
284    for (auto &Descriptor : *DescriptorList)
285      if (!parseEntry(YS, Descriptor, DL))
286        return false;
287  }
288
289  return true;
290}
291
292bool RewriteMapParser::parseEntry(yaml::Stream &YS, yaml::KeyValueNode &Entry,
293                                  RewriteDescriptorList *DL) {
294  yaml::ScalarNode *Key;
295  yaml::MappingNode *Value;
296  SmallString<32> KeyStorage;
297  StringRef RewriteType;
298
299  Key = dyn_cast<yaml::ScalarNode>(Entry.getKey());
300  if (!Key) {
301    YS.printError(Entry.getKey(), "rewrite type must be a scalar");
302    return false;
303  }
304
305  Value = dyn_cast<yaml::MappingNode>(Entry.getValue());
306  if (!Value) {
307    YS.printError(Entry.getValue(), "rewrite descriptor must be a map");
308    return false;
309  }
310
311  RewriteType = Key->getValue(KeyStorage);
312  if (RewriteType.equals("function"))
313    return parseRewriteFunctionDescriptor(YS, Key, Value, DL);
314  else if (RewriteType.equals("global variable"))
315    return parseRewriteGlobalVariableDescriptor(YS, Key, Value, DL);
316  else if (RewriteType.equals("global alias"))
317    return parseRewriteGlobalAliasDescriptor(YS, Key, Value, DL);
318
319  YS.printError(Entry.getKey(), "unknown rewrite type");
320  return false;
321}
322
323bool RewriteMapParser::
324parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
325                               yaml::MappingNode *Descriptor,
326                               RewriteDescriptorList *DL) {
327  bool Naked = false;
328  std::string Source;
329  std::string Target;
330  std::string Transform;
331
332  for (auto &Field : *Descriptor) {
333    yaml::ScalarNode *Key;
334    yaml::ScalarNode *Value;
335    SmallString<32> KeyStorage;
336    SmallString<32> ValueStorage;
337    StringRef KeyValue;
338
339    Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
340    if (!Key) {
341      YS.printError(Field.getKey(), "descriptor key must be a scalar");
342      return false;
343    }
344
345    Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
346    if (!Value) {
347      YS.printError(Field.getValue(), "descriptor value must be a scalar");
348      return false;
349    }
350
351    KeyValue = Key->getValue(KeyStorage);
352    if (KeyValue.equals("source")) {
353      std::string Error;
354
355      Source = Value->getValue(ValueStorage);
356      if (!Regex(Source).isValid(Error)) {
357        YS.printError(Field.getKey(), "invalid regex: " + Error);
358        return false;
359      }
360    } else if (KeyValue.equals("target")) {
361      Target = Value->getValue(ValueStorage);
362    } else if (KeyValue.equals("transform")) {
363      Transform = Value->getValue(ValueStorage);
364    } else if (KeyValue.equals("naked")) {
365      std::string Undecorated;
366
367      Undecorated = Value->getValue(ValueStorage);
368      Naked = StringRef(Undecorated).lower() == "true" || Undecorated == "1";
369    } else {
370      YS.printError(Field.getKey(), "unknown key for function");
371      return false;
372    }
373  }
374
375  if (Transform.empty() == Target.empty()) {
376    YS.printError(Descriptor,
377                  "exactly one of transform or target must be specified");
378    return false;
379  }
380
381  // TODO see if there is a more elegant solution to selecting the rewrite
382  // descriptor type
383  if (!Target.empty())
384    DL->push_back(std::make_unique<ExplicitRewriteFunctionDescriptor>(
385        Source, Target, Naked));
386  else
387    DL->push_back(
388        std::make_unique<PatternRewriteFunctionDescriptor>(Source, Transform));
389
390  return true;
391}
392
393bool RewriteMapParser::
394parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
395                                     yaml::MappingNode *Descriptor,
396                                     RewriteDescriptorList *DL) {
397  std::string Source;
398  std::string Target;
399  std::string Transform;
400
401  for (auto &Field : *Descriptor) {
402    yaml::ScalarNode *Key;
403    yaml::ScalarNode *Value;
404    SmallString<32> KeyStorage;
405    SmallString<32> ValueStorage;
406    StringRef KeyValue;
407
408    Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
409    if (!Key) {
410      YS.printError(Field.getKey(), "descriptor Key must be a scalar");
411      return false;
412    }
413
414    Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
415    if (!Value) {
416      YS.printError(Field.getValue(), "descriptor value must be a scalar");
417      return false;
418    }
419
420    KeyValue = Key->getValue(KeyStorage);
421    if (KeyValue.equals("source")) {
422      std::string Error;
423
424      Source = Value->getValue(ValueStorage);
425      if (!Regex(Source).isValid(Error)) {
426        YS.printError(Field.getKey(), "invalid regex: " + Error);
427        return false;
428      }
429    } else if (KeyValue.equals("target")) {
430      Target = Value->getValue(ValueStorage);
431    } else if (KeyValue.equals("transform")) {
432      Transform = Value->getValue(ValueStorage);
433    } else {
434      YS.printError(Field.getKey(), "unknown Key for Global Variable");
435      return false;
436    }
437  }
438
439  if (Transform.empty() == Target.empty()) {
440    YS.printError(Descriptor,
441                  "exactly one of transform or target must be specified");
442    return false;
443  }
444
445  if (!Target.empty())
446    DL->push_back(std::make_unique<ExplicitRewriteGlobalVariableDescriptor>(
447        Source, Target,
448        /*Naked*/ false));
449  else
450    DL->push_back(std::make_unique<PatternRewriteGlobalVariableDescriptor>(
451        Source, Transform));
452
453  return true;
454}
455
456bool RewriteMapParser::
457parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
458                                  yaml::MappingNode *Descriptor,
459                                  RewriteDescriptorList *DL) {
460  std::string Source;
461  std::string Target;
462  std::string Transform;
463
464  for (auto &Field : *Descriptor) {
465    yaml::ScalarNode *Key;
466    yaml::ScalarNode *Value;
467    SmallString<32> KeyStorage;
468    SmallString<32> ValueStorage;
469    StringRef KeyValue;
470
471    Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
472    if (!Key) {
473      YS.printError(Field.getKey(), "descriptor key must be a scalar");
474      return false;
475    }
476
477    Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
478    if (!Value) {
479      YS.printError(Field.getValue(), "descriptor value must be a scalar");
480      return false;
481    }
482
483    KeyValue = Key->getValue(KeyStorage);
484    if (KeyValue.equals("source")) {
485      std::string Error;
486
487      Source = Value->getValue(ValueStorage);
488      if (!Regex(Source).isValid(Error)) {
489        YS.printError(Field.getKey(), "invalid regex: " + Error);
490        return false;
491      }
492    } else if (KeyValue.equals("target")) {
493      Target = Value->getValue(ValueStorage);
494    } else if (KeyValue.equals("transform")) {
495      Transform = Value->getValue(ValueStorage);
496    } else {
497      YS.printError(Field.getKey(), "unknown key for Global Alias");
498      return false;
499    }
500  }
501
502  if (Transform.empty() == Target.empty()) {
503    YS.printError(Descriptor,
504                  "exactly one of transform or target must be specified");
505    return false;
506  }
507
508  if (!Target.empty())
509    DL->push_back(std::make_unique<ExplicitRewriteNamedAliasDescriptor>(
510        Source, Target,
511        /*Naked*/ false));
512  else
513    DL->push_back(std::make_unique<PatternRewriteNamedAliasDescriptor>(
514        Source, Transform));
515
516  return true;
517}
518
519namespace {
520
521class RewriteSymbolsLegacyPass : public ModulePass {
522public:
523  static char ID; // Pass identification, replacement for typeid
524
525  RewriteSymbolsLegacyPass();
526  RewriteSymbolsLegacyPass(SymbolRewriter::RewriteDescriptorList &DL);
527
528  bool runOnModule(Module &M) override;
529
530private:
531  RewriteSymbolPass Impl;
532};
533
534} // end anonymous namespace
535
536char RewriteSymbolsLegacyPass::ID = 0;
537
538RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass() : ModulePass(ID) {
539  initializeRewriteSymbolsLegacyPassPass(*PassRegistry::getPassRegistry());
540}
541
542RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass(
543    SymbolRewriter::RewriteDescriptorList &DL)
544    : ModulePass(ID), Impl(DL) {}
545
546bool RewriteSymbolsLegacyPass::runOnModule(Module &M) {
547  return Impl.runImpl(M);
548}
549
550PreservedAnalyses RewriteSymbolPass::run(Module &M, ModuleAnalysisManager &AM) {
551  if (!runImpl(M))
552    return PreservedAnalyses::all();
553
554  return PreservedAnalyses::none();
555}
556
557bool RewriteSymbolPass::runImpl(Module &M) {
558  bool Changed;
559
560  Changed = false;
561  for (auto &Descriptor : Descriptors)
562    Changed |= Descriptor->performOnModule(M);
563
564  return Changed;
565}
566
567void RewriteSymbolPass::loadAndParseMapFiles() {
568  const std::vector<std::string> MapFiles(RewriteMapFiles);
569  SymbolRewriter::RewriteMapParser Parser;
570
571  for (const auto &MapFile : MapFiles)
572    Parser.parse(MapFile, &Descriptors);
573}
574
575INITIALIZE_PASS(RewriteSymbolsLegacyPass, "rewrite-symbols", "Rewrite Symbols",
576                false, false)
577
578ModulePass *llvm::createRewriteSymbolsPass() {
579  return new RewriteSymbolsLegacyPass();
580}
581
582ModulePass *
583llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) {
584  return new RewriteSymbolsLegacyPass(DL);
585}
586