1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4import os
5import sys
6import json
7import filecmp
8import shutil
9import argparse
10
class Generator(object):
    """Build the C++ source of the NodeIntrospection implementation.

    Every ``Generate*`` method appends a fragment of C++ text to
    ``self.implementationContent`` in call order; ``GenerateFiles`` writes
    the accumulated text to a file.
    """

    # Class-level default kept for backward compatibility; __init__ gives
    # every instance its own buffer.
    implementationContent = ''

    # Clades whose objects are handled by const reference ("&") in the
    # generated code; every other clade is handled by const pointer ("*").
    RefClades = {"DeclarationNameInfo",
        "NestedNameSpecifierLoc",
        "TemplateArgumentLoc",
        "TypeLoc"}

    def __init__(self, templateClasses):
        """Create a generator with an empty output buffer.

        :param templateClasses: names of AST classes treated as class
            templates. They are skipped when emitting the per-clade
            ``dyn_cast`` dispatch, and are matched against the inheritance
            map when dispatching TypeLoc subclasses.
        """
        self.templateClasses = templateClasses
        self.implementationContent = ''

    def _instanceDecoration(self, CladeName):
        """Return "&" for reference clades and "*" for pointer clades."""
        return "&" if CladeName in self.RefClades else "*"

    def GeneratePrologue(self):
        """Append the file banner, namespace openings and shared helpers.

        Opens ``namespace clang`` and ``namespace tooling`` (GenerateEpilogue
        emits the two matching closing braces) and defines the
        ``RecursionPopper`` helper, whose destructor calls ``pop_back()`` on
        the TypeLoc recursion guard it holds a reference to.
        """
        # Raw string so the banner's backslashes are emitted verbatim. In a
        # plain string the trailing "\" on the first banner line would splice
        # it onto the next line, and "\*" is an invalid escape sequence
        # (DeprecationWarning; SyntaxWarning since Python 3.12).
        self.implementationContent += \
            r"""
/*===- Generated file -------------------------------------------*- C++ -*-===*\
|*                                                                            *|
|* Introspection of available AST node SourceLocations                        *|
|*                                                                            *|
|* Automatically generated file, do not edit!                                 *|
|*                                                                            *|
\*===----------------------------------------------------------------------===*/

namespace clang {
namespace tooling {

using LocationAndString = SourceLocationMap::value_type;
using RangeAndString = SourceRangeMap::value_type;

bool NodeIntrospection::hasIntrospectionSupport() { return true; }

struct RecursionPopper
{
    RecursionPopper(std::vector<clang::TypeLoc> &TypeLocRecursionGuard)
    :  TLRG(TypeLocRecursionGuard)
    {

    }

    ~RecursionPopper()
    {
    TLRG.pop_back();
    }

private:
std::vector<clang::TypeLoc> &TLRG;
};
"""

    def GenerateBaseGetLocationsDeclaration(self, CladeName):
        """Append the forward declaration of GetLocationsImpl for a clade.

        :param CladeName: clang base class name, e.g. a member of RefClades
            (taken by reference) or any other clade (taken by pointer).
        """
        self.implementationContent += \
            """
void GetLocationsImpl(SharedLocationCall const& Prefix,
    clang::{0} const {1}Object, SourceLocationMap &Locs,
    SourceRangeMap &Rngs,
    std::vector<clang::TypeLoc> &TypeLocRecursionGuard);
""".format(CladeName, self._instanceDecoration(CladeName))

    def GenerateSrcLocMethod(self,
            ClassName, ClassData, CreateLocalRecursionGuard):
        """Append a static ``GetLocations<ClassName>`` function.

        :param ClassName: AST class name (without template arguments).
        :param ClassData: accessor description for the class. The optional
            keys read here are ``templateParms``, ``sourceLocations``,
            ``sourceRanges``, ``typeLocs``, ``typeSourceInfos``,
            ``nestedNameLocs`` and ``declNameInfos``; each is iterated as a
            sequence of names.
        :param CreateLocalRecursionGuard: when true, the generated function
            declares its own ``TypeLocRecursionGuard`` local (if it needs
            one) instead of taking it as a parameter.
        """
        NormalClassName = ClassName
        RecursionGuardParam = ('' if CreateLocalRecursionGuard else
            ', std::vector<clang::TypeLoc>& TypeLocRecursionGuard')

        if "templateParms" in ClassData:
            TemplateParms = ClassData["templateParms"]
            # e.g. ["T", "U"] -> "Foo<T, U>" and "template <typename T, typename U>"
            ClassName += "<" + ", ".join(TemplateParms) + ">"
            self.implementationContent += \
                "template <typename " + ", typename ".join(TemplateParms) + ">\n"

        self.implementationContent += \
            """
static void GetLocations{0}(SharedLocationCall const& Prefix,
    clang::{1} const &Object,
    SourceLocationMap &Locs, SourceRangeMap &Rngs {2})
{{
""".format(NormalClassName, ClassName, RecursionGuardParam)

        if 'sourceLocations' in ClassData:
            for locName in ClassData['sourceLocations']:
                self.implementationContent += \
                    """
  Locs.insert(LocationAndString(Object.{0}(),
    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}")));
""".format(locName)

            self.implementationContent += '\n'

        if 'sourceRanges' in ClassData:
            for rngName in ClassData['sourceRanges']:
                self.implementationContent += \
                    """
  Rngs.insert(RangeAndString(Object.{0}(),
    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}")));
""".format(rngName)

            self.implementationContent += '\n'

        # Only these accessor kinds recurse into GetLocationsImpl and hence
        # need a TypeLocRecursionGuard in scope.
        if 'typeLocs' in ClassData or 'typeSourceInfos' in ClassData \
                or 'nestedNameLocs' in ClassData \
                or 'declNameInfos' in ClassData:
            if CreateLocalRecursionGuard:
                self.implementationContent += \
                    'std::vector<clang::TypeLoc> TypeLocRecursionGuard;\n'

            self.implementationContent += '\n'

            if 'typeLocs' in ClassData:
                for typeLoc in ClassData['typeLocs']:

                    self.implementationContent += \
                        """
              if (Object.{0}()) {{
                GetLocationsImpl(
                    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}"),
                    Object.{0}(), Locs, Rngs, TypeLocRecursionGuard);
                }}
              """.format(typeLoc)

            self.implementationContent += '\n'
            if 'typeSourceInfos' in ClassData:
                for tsi in ClassData['typeSourceInfos']:
                    self.implementationContent += \
                        """
              if (Object.{0}()) {{
                GetLocationsImpl(llvm::makeIntrusiveRefCnt<LocationCall>(
                    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}",
                        LocationCall::ReturnsPointer), "getTypeLoc"),
                    Object.{0}()->getTypeLoc(), Locs, Rngs, TypeLocRecursionGuard);
                    }}
              """.format(tsi)

                self.implementationContent += '\n'

            if 'nestedNameLocs' in ClassData:
                for NN in ClassData['nestedNameLocs']:
                    self.implementationContent += \
                        """
              if (Object.{0}())
                GetLocationsImpl(
                    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}"),
                    Object.{0}(), Locs, Rngs, TypeLocRecursionGuard);
              """.format(NN)

            if 'declNameInfos' in ClassData:
                for declName in ClassData['declNameInfos']:

                    self.implementationContent += \
                        """
                      GetLocationsImpl(
                          llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}"),
                          Object.{0}(), Locs, Rngs, TypeLocRecursionGuard);
                      """.format(declName)

        self.implementationContent += '}\n'

    def GenerateFiles(self, OutputFile):
        """Write the accumulated C++ text to OutputFile.

        A relative OutputFile is resolved against the current working
        directory; an absolute one is used as-is (os.path.join semantics).
        """
        with open(os.path.join(os.getcwd(), OutputFile), 'w',
                  encoding='utf-8') as f:
            f.write(self.implementationContent)

    def GenerateBaseGetLocationsFunction(self, ASTClassNames,
            ClassEntries, CladeName, InheritanceMap,
            CreateLocalRecursionGuard):
        """Append GetLocationsImpl and NodeIntrospection::GetLocations for a clade.

        :param ASTClassNames: classes belonging to the clade; each non-template
            class other than the clade itself gets a dispatch branch.
        :param ClassEntries: class-name -> accessor data mapping; membership
            decides whether a TypeLoc subclass gets a direct getAs<> branch.
        :param CladeName: the clade base class name.
        :param InheritanceMap: class-name -> base description, used to find
            template bases of TypeLoc subclasses.
        :param CreateLocalRecursionGuard: when false, the recursion guard is
            forwarded to the per-class GetLocations* helpers.
        """
        MethodReturnType = 'NodeLocationAccessors'
        InstanceDecoration = self._instanceDecoration(CladeName)

        Signature = \
            'GetLocations(clang::{0} const {1}Object)'.format(
                CladeName, InstanceDecoration)
        ImplSignature = \
            """
    GetLocationsImpl(SharedLocationCall const& Prefix,
        clang::{0} const {1}Object, SourceLocationMap &Locs,
        SourceRangeMap &Rngs,
        std::vector<clang::TypeLoc> &TypeLocRecursionGuard)
    """.format(CladeName, InstanceDecoration)

        self.implementationContent += 'void {0} {{ '.format(ImplSignature)

        if CladeName == "TypeLoc":
            self.implementationContent += 'if (Object.isNull()) return;'

            # Stop if this TypeLoc is already being visited higher up the
            # call chain; RecursionPopper pops it again on scope exit.
            self.implementationContent += \
                """
            if (llvm::find(TypeLocRecursionGuard, Object) != TypeLocRecursionGuard.end())
              return;
            TypeLocRecursionGuard.push_back(Object);
            RecursionPopper RAII(TypeLocRecursionGuard);
                """

        RecursionGuardParam = ''
        if not CreateLocalRecursionGuard:
            RecursionGuardParam = ', TypeLocRecursionGuard'

        # Pointer clades dereference Object; reference clades pass it as is.
        ArgPrefix = '' if CladeName in self.RefClades else '*'
        self.implementationContent += \
            'GetLocations{0}(Prefix, {1}Object, Locs, Rngs {2});'.format(
                CladeName, ArgPrefix, RecursionGuardParam)

        if CladeName == "TypeLoc":
            self.implementationContent += \
                '''
        if (auto QTL = Object.getAs<clang::QualifiedTypeLoc>()) {
            auto Dequalified = QTL.getNextTypeLoc();
            return GetLocationsImpl(llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "getNextTypeLoc"),
                                Dequalified,
                                Locs,
                                Rngs,
                                TypeLocRecursionGuard);
        }'''

        for ASTClassName in ASTClassNames:
            if ASTClassName in self.templateClasses:
                continue
            if ASTClassName == CladeName:
                continue
            if CladeName != "TypeLoc":
                self.implementationContent += \
                """
if (auto Derived = llvm::dyn_cast<clang::{0}>(Object)) {{
  GetLocations{0}(Prefix, *Derived, Locs, Rngs {1});
}}
""".format(ASTClassName, RecursionGuardParam)
                continue

            self.GenerateBaseTypeLocVisit(ASTClassName, ClassEntries,
                RecursionGuardParam, InheritanceMap)

        self.implementationContent += '}'

        self.implementationContent += \
            """
{0} NodeIntrospection::{1} {{
  NodeLocationAccessors Result;
  SharedLocationCall Prefix;
  std::vector<clang::TypeLoc> TypeLocRecursionGuard;

  GetLocationsImpl(Prefix, Object, Result.LocationAccessors,
                   Result.RangeAccessors, TypeLocRecursionGuard);
""".format(MethodReturnType, Signature)

        self.implementationContent += 'return Result; }'

    def GenerateBaseTypeLocVisit(self, ASTClassName, ClassEntries,
            RecursionGuardParam, InheritanceMap):
        """Append getAs<> dispatch branches for one TypeLoc subclass.

        Emits a direct branch when the class has its own accessor entry, and
        one branch per known template class found in its InheritanceMap entry.
        """
        CallPrefix = 'Prefix'
        if ASTClassName != 'TypeLoc':
            CallPrefix = \
                '''llvm::makeIntrusiveRefCnt<LocationCall>(Prefix,
                    "getAs<clang::{0}>", LocationCall::IsCast)
                '''.format(ASTClassName)

        if ASTClassName in ClassEntries:

            self.implementationContent += \
            """
            if (auto ConcreteTL = Object.getAs<clang::{0}>())
              GetLocations{1}({2}, ConcreteTL, Locs, Rngs {3});
            """.format(ASTClassName, ASTClassName,
                       CallPrefix, RecursionGuardParam)

        if ASTClassName in InheritanceMap:
            for baseTemplate in self.templateClasses:
                # NOTE(review): substring/membership test on the inheritance
                # entry; assumes the entry names its template base — confirm.
                if baseTemplate in InheritanceMap[ASTClassName]:
                    self.implementationContent += \
                    """
    if (auto ConcreteTL = Object.getAs<clang::{0}>())
      GetLocations{1}({2}, ConcreteTL, Locs, Rngs {3});
    """.format(InheritanceMap[ASTClassName], baseTemplate,
            CallPrefix, RecursionGuardParam)

    def GenerateDynNodeVisitor(self, CladeNames):
        """Append NodeIntrospection::GetLocations(DynTypedNode const &).

        Emits one ``Node.get<Clade>()`` branch per clade except
        DeclarationNameInfo, falling back to an empty result.
        """
        MethodReturnType = 'NodeLocationAccessors'

        Signature = \
            'GetLocations(clang::DynTypedNode const &Node)'

        self.implementationContent += MethodReturnType \
            + ' NodeIntrospection::' + Signature + '{'

        for CladeName in CladeNames:
            if CladeName == "DeclarationNameInfo":
                continue
            self.implementationContent += \
                """
    if (const auto *N = Node.get<{0}>())
    """.format(CladeName)
            # Reference clades need the pointer from Node.get<> dereferenced.
            ArgPrefix = "*" if CladeName in self.RefClades else ""
            self.implementationContent += \
            """
      return GetLocations({0}const_cast<{1} *>(N));""".format(ArgPrefix, CladeName)

        self.implementationContent += '\nreturn {}; }'

    def GenerateEpilogue(self):
        """Append the two closing braces matching the prologue's namespaces."""
        self.implementationContent += '''
  }
}
'''
337
def main():
    """Command-line entry point: write the NodeIntrospection implementation.

    Either copies the file named by ``--empty-implementation`` to
    ``--output-file`` (only when their contents differ) and exits with
    status 0, or generates the full implementation from the JSON API
    description at ``--json-input-path``.

    The empty implementation is chosen when ``--use-empty-implementation``
    is non-zero, when the JSON input path is omitted or does not exist, or
    when the JSON has no non-empty ``classesInClade`` entry.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--json-input-path',
                      help='Read API description from FILE', metavar='FILE')
    parser.add_argument('--output-file', help='Generate output in FILEPATH',
                      metavar='FILEPATH')
    parser.add_argument('--use-empty-implementation',
                      help='Generate empty implementation',
                      action="store", type=int)
    parser.add_argument('--empty-implementation',
                      help='Copy empty implementation from FILEPATH',
                      action="store", metavar='FILEPATH')

    options = parser.parse_args()

    # Every path below writes the output file; report a usage error instead
    # of failing with a TypeError inside os.path.
    if not options.output_file:
        parser.error('--output-file is required')

    use_empty_implementation = options.use_empty_implementation

    # An omitted --json-input-path behaves like a path that does not exist
    # (os.path.exists(None) would raise TypeError).
    if (not use_empty_implementation
            and (not options.json_input_path
                 or not os.path.exists(options.json_input_path))):
        use_empty_implementation = True

    if not use_empty_implementation:
        with open(options.json_input_path, encoding='utf-8') as f:
            jsonData = json.load(f)

        if 'classesInClade' not in jsonData or not jsonData['classesInClade']:
            use_empty_implementation = True

    if use_empty_implementation:
        if not options.empty_implementation:
            parser.error('--empty-implementation is required when no '
                         'introspection data is available')
        # Only copy when the output is missing or differs, so an unchanged
        # output file keeps its timestamp.
        if not os.path.exists(options.output_file) or \
                not filecmp.cmp(options.empty_implementation, options.output_file):
            shutil.copyfile(options.empty_implementation, options.output_file)
        sys.exit(0)

    templateClasses = [
        ClassName
        for ClassName, ClassAccessors in jsonData['classEntries'].items()
        if "templateParms" in ClassAccessors
    ]

    g = Generator(templateClasses)

    g.GeneratePrologue()

    for CladeName in jsonData['classesInClade']:
        g.GenerateBaseGetLocationsDeclaration(CladeName)

    # Map each class to the FIRST clade that lists it, built once instead of
    # re-scanning every clade for every class.
    cladeOfClass = {}
    for CladeName, ClassNameData in jsonData['classesInClade'].items():
        for ClassName in ClassNameData:
            cladeOfClass.setdefault(ClassName, CladeName)

    for ClassName, ClassAccessors in jsonData['classEntries'].items():
        # Classes in no clade (lookup yields None) get a local recursion guard.
        g.GenerateSrcLocMethod(
            ClassName, ClassAccessors,
            cladeOfClass.get(ClassName) not in Generator.RefClades)

    for CladeName, ClassNameData in jsonData['classesInClade'].items():
        g.GenerateBaseGetLocationsFunction(
            ClassNameData,
            jsonData['classEntries'],
            CladeName,
            jsonData["classInheritance"],
            CladeName not in Generator.RefClades)

    g.GenerateDynNodeVisitor(jsonData['classesInClade'].keys())

    g.GenerateEpilogue()

    g.GenerateFiles(options.output_file)
409
if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported.
    main()
412