1//===--- HLSLExternalSemaSource.cpp - HLSL Sema Source --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//
10//===----------------------------------------------------------------------===//
11
12#include "clang/Sema/HLSLExternalSemaSource.h"
13#include "clang/AST/ASTContext.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/DeclCXX.h"
16#include "clang/Basic/AttrKinds.h"
17#include "clang/Basic/HLSLRuntime.h"
18#include "clang/Sema/Lookup.h"
19#include "clang/Sema/Sema.h"
20#include "llvm/Frontend/HLSL/HLSLResource.h"
21
22#include <functional>
23
24using namespace clang;
25using namespace llvm::hlsl;
26
27namespace {
28
29struct TemplateParameterListBuilder;
30
31struct BuiltinTypeDeclBuilder {
32  CXXRecordDecl *Record = nullptr;
33  ClassTemplateDecl *Template = nullptr;
34  ClassTemplateDecl *PrevTemplate = nullptr;
35  NamespaceDecl *HLSLNamespace = nullptr;
36  llvm::StringMap<FieldDecl *> Fields;
37
38  BuiltinTypeDeclBuilder(CXXRecordDecl *R) : Record(R) {
39    Record->startDefinition();
40    Template = Record->getDescribedClassTemplate();
41  }
42
43  BuiltinTypeDeclBuilder(Sema &S, NamespaceDecl *Namespace, StringRef Name)
44      : HLSLNamespace(Namespace) {
45    ASTContext &AST = S.getASTContext();
46    IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
47
48    LookupResult Result(S, &II, SourceLocation(), Sema::LookupTagName);
49    CXXRecordDecl *PrevDecl = nullptr;
50    if (S.LookupQualifiedName(Result, HLSLNamespace)) {
51      NamedDecl *Found = Result.getFoundDecl();
52      if (auto *TD = dyn_cast<ClassTemplateDecl>(Found)) {
53        PrevDecl = TD->getTemplatedDecl();
54        PrevTemplate = TD;
55      } else
56        PrevDecl = dyn_cast<CXXRecordDecl>(Found);
57      assert(PrevDecl && "Unexpected lookup result type.");
58    }
59
60    if (PrevDecl && PrevDecl->isCompleteDefinition()) {
61      Record = PrevDecl;
62      return;
63    }
64
65    Record = CXXRecordDecl::Create(AST, TagDecl::TagKind::TTK_Class,
66                                   HLSLNamespace, SourceLocation(),
67                                   SourceLocation(), &II, PrevDecl, true);
68    Record->setImplicit(true);
69    Record->setLexicalDeclContext(HLSLNamespace);
70    Record->setHasExternalLexicalStorage();
71
72    // Don't let anyone derive from built-in types.
73    Record->addAttr(FinalAttr::CreateImplicit(AST, SourceRange(),
74                                              AttributeCommonInfo::AS_Keyword,
75                                              FinalAttr::Keyword_final));
76  }
77
78  ~BuiltinTypeDeclBuilder() {
79    if (HLSLNamespace && !Template && Record->getDeclContext() == HLSLNamespace)
80      HLSLNamespace->addDecl(Record);
81  }
82
83  BuiltinTypeDeclBuilder &
84  addMemberVariable(StringRef Name, QualType Type,
85                    AccessSpecifier Access = AccessSpecifier::AS_private) {
86    if (Record->isCompleteDefinition())
87      return *this;
88    assert(Record->isBeingDefined() &&
89           "Definition must be started before adding members!");
90    ASTContext &AST = Record->getASTContext();
91
92    IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
93    TypeSourceInfo *MemTySource =
94        AST.getTrivialTypeSourceInfo(Type, SourceLocation());
95    auto *Field = FieldDecl::Create(
96        AST, Record, SourceLocation(), SourceLocation(), &II, Type, MemTySource,
97        nullptr, false, InClassInitStyle::ICIS_NoInit);
98    Field->setAccess(Access);
99    Field->setImplicit(true);
100    Record->addDecl(Field);
101    Fields[Name] = Field;
102    return *this;
103  }
104
105  BuiltinTypeDeclBuilder &
106  addHandleMember(AccessSpecifier Access = AccessSpecifier::AS_private) {
107    if (Record->isCompleteDefinition())
108      return *this;
109    QualType Ty = Record->getASTContext().VoidPtrTy;
110    if (Template) {
111      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
112              Template->getTemplateParameters()->getParam(0)))
113        Ty = Record->getASTContext().getPointerType(
114            QualType(TTD->getTypeForDecl(), 0));
115    }
116    return addMemberVariable("h", Ty, Access);
117  }
118
119  BuiltinTypeDeclBuilder &
120  annotateResourceClass(HLSLResourceAttr::ResourceClass RC,
121                        HLSLResourceAttr::ResourceKind RK) {
122    if (Record->isCompleteDefinition())
123      return *this;
124    Record->addAttr(
125        HLSLResourceAttr::CreateImplicit(Record->getASTContext(), RC, RK));
126    return *this;
127  }
128
129  static DeclRefExpr *lookupBuiltinFunction(ASTContext &AST, Sema &S,
130                                            StringRef Name) {
131    CXXScopeSpec SS;
132    IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
133    DeclarationNameInfo NameInfo =
134        DeclarationNameInfo(DeclarationName(&II), SourceLocation());
135    LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
136    S.LookupParsedName(R, S.getCurScope(), &SS, false);
137    assert(R.isSingleResult() &&
138           "Since this is a builtin it should always resolve!");
139    auto *VD = cast<ValueDecl>(R.getFoundDecl());
140    QualType Ty = VD->getType();
141    return DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), SourceLocation(),
142                               VD, false, NameInfo, Ty, VK_PRValue);
143  }
144
145  static Expr *emitResourceClassExpr(ASTContext &AST, ResourceClass RC) {
146    return IntegerLiteral::Create(
147        AST,
148        llvm::APInt(AST.getIntWidth(AST.UnsignedCharTy),
149                    static_cast<uint8_t>(RC)),
150        AST.UnsignedCharTy, SourceLocation());
151  }
152
153  BuiltinTypeDeclBuilder &addDefaultHandleConstructor(Sema &S,
154                                                      ResourceClass RC) {
155    if (Record->isCompleteDefinition())
156      return *this;
157    ASTContext &AST = Record->getASTContext();
158
159    QualType ConstructorType =
160        AST.getFunctionType(AST.VoidTy, {}, FunctionProtoType::ExtProtoInfo());
161
162    CanQualType CanTy = Record->getTypeForDecl()->getCanonicalTypeUnqualified();
163    DeclarationName Name = AST.DeclarationNames.getCXXConstructorName(CanTy);
164    CXXConstructorDecl *Constructor = CXXConstructorDecl::Create(
165        AST, Record, SourceLocation(),
166        DeclarationNameInfo(Name, SourceLocation()), ConstructorType,
167        AST.getTrivialTypeSourceInfo(ConstructorType, SourceLocation()),
168        ExplicitSpecifier(), false, true, false,
169        ConstexprSpecKind::Unspecified);
170
171    DeclRefExpr *Fn =
172        lookupBuiltinFunction(AST, S, "__builtin_hlsl_create_handle");
173
174    Expr *RCExpr = emitResourceClassExpr(AST, RC);
175    Expr *Call = CallExpr::Create(AST, Fn, {RCExpr}, AST.VoidPtrTy, VK_PRValue,
176                                  SourceLocation(), FPOptionsOverride());
177
178    CXXThisExpr *This = new (AST) CXXThisExpr(
179        SourceLocation(),
180        Constructor->getThisType().getTypePtr()->getPointeeType(), true);
181    This->setValueKind(ExprValueKind::VK_LValue);
182    Expr *Handle = MemberExpr::CreateImplicit(AST, This, false, Fields["h"],
183                                              Fields["h"]->getType(), VK_LValue,
184                                              OK_Ordinary);
185
186    // If the handle isn't a void pointer, cast the builtin result to the
187    // correct type.
188    if (Handle->getType().getCanonicalType() != AST.VoidPtrTy) {
189      Call = CXXStaticCastExpr::Create(
190          AST, Handle->getType(), VK_PRValue, CK_Dependent, Call, nullptr,
191          AST.getTrivialTypeSourceInfo(Handle->getType(), SourceLocation()),
192          FPOptionsOverride(), SourceLocation(), SourceLocation(),
193          SourceRange());
194    }
195
196    BinaryOperator *Assign = BinaryOperator::Create(
197        AST, Handle, Call, BO_Assign, Handle->getType(), VK_LValue, OK_Ordinary,
198        SourceLocation(), FPOptionsOverride());
199
200    Constructor->setBody(
201        CompoundStmt::Create(AST, {Assign}, FPOptionsOverride(),
202                             SourceLocation(), SourceLocation()));
203    Constructor->setAccess(AccessSpecifier::AS_public);
204    Record->addDecl(Constructor);
205    return *this;
206  }
207
208  BuiltinTypeDeclBuilder &addArraySubscriptOperators() {
209    if (Record->isCompleteDefinition())
210      return *this;
211    addArraySubscriptOperator(true);
212    addArraySubscriptOperator(false);
213    return *this;
214  }
215
216  BuiltinTypeDeclBuilder &addArraySubscriptOperator(bool IsConst) {
217    if (Record->isCompleteDefinition())
218      return *this;
219    assert(Fields.count("h") > 0 &&
220           "Subscript operator must be added after the handle.");
221
222    FieldDecl *Handle = Fields["h"];
223    ASTContext &AST = Record->getASTContext();
224
225    assert(Handle->getType().getCanonicalType() != AST.VoidPtrTy &&
226           "Not yet supported for void pointer handles.");
227
228    QualType ElemTy =
229        QualType(Handle->getType()->getPointeeOrArrayElementType(), 0);
230    QualType ReturnTy = ElemTy;
231
232    FunctionProtoType::ExtProtoInfo ExtInfo;
233
234    // Subscript operators return references to elements, const makes the
235    // reference and method const so that the underlying data is not mutable.
236    ReturnTy = AST.getLValueReferenceType(ReturnTy);
237    if (IsConst) {
238      ExtInfo.TypeQuals.addConst();
239      ReturnTy.addConst();
240    }
241
242    QualType MethodTy =
243        AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo);
244    auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
245    auto *MethodDecl = CXXMethodDecl::Create(
246        AST, Record, SourceLocation(),
247        DeclarationNameInfo(
248            AST.DeclarationNames.getCXXOperatorName(OO_Subscript),
249            SourceLocation()),
250        MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified,
251        SourceLocation());
252
253    IdentifierInfo &II = AST.Idents.get("Idx", tok::TokenKind::identifier);
254    auto *IdxParam = ParmVarDecl::Create(
255        AST, MethodDecl->getDeclContext(), SourceLocation(), SourceLocation(),
256        &II, AST.UnsignedIntTy,
257        AST.getTrivialTypeSourceInfo(AST.UnsignedIntTy, SourceLocation()),
258        SC_None, nullptr);
259    MethodDecl->setParams({IdxParam});
260
261    // Also add the parameter to the function prototype.
262    auto FnProtoLoc = TSInfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
263    FnProtoLoc.setParam(0, IdxParam);
264
265    auto *This = new (AST) CXXThisExpr(
266        SourceLocation(),
267        MethodDecl->getThisType().getTypePtr()->getPointeeType(), true);
268    This->setValueKind(ExprValueKind::VK_LValue);
269    auto *HandleAccess = MemberExpr::CreateImplicit(
270        AST, This, false, Handle, Handle->getType(), VK_LValue, OK_Ordinary);
271
272    auto *IndexExpr = DeclRefExpr::Create(
273        AST, NestedNameSpecifierLoc(), SourceLocation(), IdxParam, false,
274        DeclarationNameInfo(IdxParam->getDeclName(), SourceLocation()),
275        AST.UnsignedIntTy, VK_PRValue);
276
277    auto *Array =
278        new (AST) ArraySubscriptExpr(HandleAccess, IndexExpr, ElemTy, VK_LValue,
279                                     OK_Ordinary, SourceLocation());
280
281    auto *Return = ReturnStmt::Create(AST, SourceLocation(), Array, nullptr);
282
283    MethodDecl->setBody(CompoundStmt::Create(AST, {Return}, FPOptionsOverride(),
284                                             SourceLocation(),
285                                             SourceLocation()));
286    MethodDecl->setLexicalDeclContext(Record);
287    MethodDecl->setAccess(AccessSpecifier::AS_public);
288    MethodDecl->addAttr(AlwaysInlineAttr::CreateImplicit(
289        AST, SourceRange(), AttributeCommonInfo::AS_Keyword,
290        AlwaysInlineAttr::CXX11_clang_always_inline));
291    Record->addDecl(MethodDecl);
292
293    return *this;
294  }
295
296  BuiltinTypeDeclBuilder &startDefinition() {
297    if (Record->isCompleteDefinition())
298      return *this;
299    Record->startDefinition();
300    return *this;
301  }
302
303  BuiltinTypeDeclBuilder &completeDefinition() {
304    if (Record->isCompleteDefinition())
305      return *this;
306    assert(Record->isBeingDefined() &&
307           "Definition must be started before completing it.");
308
309    Record->completeDefinition();
310    return *this;
311  }
312
313  TemplateParameterListBuilder addTemplateArgumentList();
314};
315
316struct TemplateParameterListBuilder {
317  BuiltinTypeDeclBuilder &Builder;
318  ASTContext &AST;
319  llvm::SmallVector<NamedDecl *> Params;
320
321  TemplateParameterListBuilder(BuiltinTypeDeclBuilder &RB)
322      : Builder(RB), AST(RB.Record->getASTContext()) {}
323
324  ~TemplateParameterListBuilder() { finalizeTemplateArgs(); }
325
326  TemplateParameterListBuilder &
327  addTypeParameter(StringRef Name, QualType DefaultValue = QualType()) {
328    if (Builder.Record->isCompleteDefinition())
329      return *this;
330    unsigned Position = static_cast<unsigned>(Params.size());
331    auto *Decl = TemplateTypeParmDecl::Create(
332        AST, Builder.Record->getDeclContext(), SourceLocation(),
333        SourceLocation(), /* TemplateDepth */ 0, Position,
334        &AST.Idents.get(Name, tok::TokenKind::identifier), /* Typename */ false,
335        /* ParameterPack */ false);
336    if (!DefaultValue.isNull())
337      Decl->setDefaultArgument(AST.getTrivialTypeSourceInfo(DefaultValue));
338
339    Params.emplace_back(Decl);
340    return *this;
341  }
342
343  BuiltinTypeDeclBuilder &finalizeTemplateArgs() {
344    if (Params.empty())
345      return Builder;
346    auto *ParamList =
347        TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
348                                      Params, SourceLocation(), nullptr);
349    Builder.Template = ClassTemplateDecl::Create(
350        AST, Builder.Record->getDeclContext(), SourceLocation(),
351        DeclarationName(Builder.Record->getIdentifier()), ParamList,
352        Builder.Record);
353    Builder.Record->setDescribedClassTemplate(Builder.Template);
354    Builder.Template->setImplicit(true);
355    Builder.Template->setLexicalDeclContext(Builder.Record->getDeclContext());
356    // NOTE: setPreviousDecl before addDecl so new decl replace old decl when
357    // make visible.
358    Builder.Template->setPreviousDecl(Builder.PrevTemplate);
359    Builder.Record->getDeclContext()->addDecl(Builder.Template);
360    Params.clear();
361
362    QualType T = Builder.Template->getInjectedClassNameSpecialization();
363    T = AST.getInjectedClassNameType(Builder.Record, T);
364
365    return Builder;
366  }
367};
368
369TemplateParameterListBuilder BuiltinTypeDeclBuilder::addTemplateArgumentList() {
370  return TemplateParameterListBuilder(*this);
371}
372} // namespace
373
374HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
375
376void HLSLExternalSemaSource::InitializeSema(Sema &S) {
377  SemaPtr = &S;
378  ASTContext &AST = SemaPtr->getASTContext();
379  // If the translation unit has external storage force external decls to load.
380  if (AST.getTranslationUnitDecl()->hasExternalLexicalStorage())
381    (void)AST.getTranslationUnitDecl()->decls_begin();
382
383  IdentifierInfo &HLSL = AST.Idents.get("hlsl", tok::TokenKind::identifier);
384  LookupResult Result(S, &HLSL, SourceLocation(), Sema::LookupNamespaceName);
385  NamespaceDecl *PrevDecl = nullptr;
386  if (S.LookupQualifiedName(Result, AST.getTranslationUnitDecl()))
387    PrevDecl = Result.getAsSingle<NamespaceDecl>();
388  HLSLNamespace = NamespaceDecl::Create(
389      AST, AST.getTranslationUnitDecl(), /*Inline=*/false, SourceLocation(),
390      SourceLocation(), &HLSL, PrevDecl, /*Nested=*/false);
391  HLSLNamespace->setImplicit(true);
392  HLSLNamespace->setHasExternalLexicalStorage();
393  AST.getTranslationUnitDecl()->addDecl(HLSLNamespace);
394
395  // Force external decls in the HLSL namespace to load from the PCH.
396  (void)HLSLNamespace->getCanonicalDecl()->decls_begin();
397  defineTrivialHLSLTypes();
398  forwardDeclareHLSLTypes();
399
400  // This adds a `using namespace hlsl` directive. In DXC, we don't put HLSL's
401  // built in types inside a namespace, but we are planning to change that in
402  // the near future. In order to be source compatible older versions of HLSL
403  // will need to implicitly use the hlsl namespace. For now in clang everything
404  // will get added to the namespace, and we can remove the using directive for
405  // future language versions to match HLSL's evolution.
406  auto *UsingDecl = UsingDirectiveDecl::Create(
407      AST, AST.getTranslationUnitDecl(), SourceLocation(), SourceLocation(),
408      NestedNameSpecifierLoc(), SourceLocation(), HLSLNamespace,
409      AST.getTranslationUnitDecl());
410
411  AST.getTranslationUnitDecl()->addDecl(UsingDecl);
412}
413
414void HLSLExternalSemaSource::defineHLSLVectorAlias() {
415  ASTContext &AST = SemaPtr->getASTContext();
416
417  llvm::SmallVector<NamedDecl *> TemplateParams;
418
419  auto *TypeParam = TemplateTypeParmDecl::Create(
420      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
421      &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
422  TypeParam->setDefaultArgument(AST.getTrivialTypeSourceInfo(AST.FloatTy));
423
424  TemplateParams.emplace_back(TypeParam);
425
426  auto *SizeParam = NonTypeTemplateParmDecl::Create(
427      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
428      &AST.Idents.get("element_count", tok::TokenKind::identifier), AST.IntTy,
429      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
430  Expr *LiteralExpr =
431      IntegerLiteral::Create(AST, llvm::APInt(AST.getIntWidth(AST.IntTy), 4),
432                             AST.IntTy, SourceLocation());
433  SizeParam->setDefaultArgument(LiteralExpr);
434  TemplateParams.emplace_back(SizeParam);
435
436  auto *ParamList =
437      TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
438                                    TemplateParams, SourceLocation(), nullptr);
439
440  IdentifierInfo &II = AST.Idents.get("vector", tok::TokenKind::identifier);
441
442  QualType AliasType = AST.getDependentSizedExtVectorType(
443      AST.getTemplateTypeParmType(0, 0, false, TypeParam),
444      DeclRefExpr::Create(
445          AST, NestedNameSpecifierLoc(), SourceLocation(), SizeParam, false,
446          DeclarationNameInfo(SizeParam->getDeclName(), SourceLocation()),
447          AST.IntTy, VK_LValue),
448      SourceLocation());
449
450  auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
451                                       SourceLocation(), &II,
452                                       AST.getTrivialTypeSourceInfo(AliasType));
453  Record->setImplicit(true);
454
455  auto *Template =
456      TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
457                                    Record->getIdentifier(), ParamList, Record);
458
459  Record->setDescribedAliasTemplate(Template);
460  Template->setImplicit(true);
461  Template->setLexicalDeclContext(Record->getDeclContext());
462  HLSLNamespace->addDecl(Template);
463}
464
465void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
466  defineHLSLVectorAlias();
467
468  ResourceDecl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Resource")
469                     .startDefinition()
470                     .addHandleMember(AccessSpecifier::AS_public)
471                     .completeDefinition()
472                     .Record;
473}
474
475void HLSLExternalSemaSource::forwardDeclareHLSLTypes() {
476  CXXRecordDecl *Decl;
477  Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer")
478             .addTemplateArgumentList()
479             .addTypeParameter("element_type", SemaPtr->getASTContext().FloatTy)
480             .finalizeTemplateArgs()
481             .Record;
482  if (!Decl->isCompleteDefinition())
483    Completions.insert(
484        std::make_pair(Decl->getCanonicalDecl(),
485                       std::bind(&HLSLExternalSemaSource::completeBufferType,
486                                 this, std::placeholders::_1)));
487}
488
489void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
490  if (!isa<CXXRecordDecl>(Tag))
491    return;
492  auto Record = cast<CXXRecordDecl>(Tag);
493
494  // If this is a specialization, we need to get the underlying templated
495  // declaration and complete that.
496  if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(Record))
497    Record = TDecl->getSpecializedTemplate()->getTemplatedDecl();
498  Record = Record->getCanonicalDecl();
499  auto It = Completions.find(Record);
500  if (It == Completions.end())
501    return;
502  It->second(Record);
503}
504
505void HLSLExternalSemaSource::completeBufferType(CXXRecordDecl *Record) {
506  BuiltinTypeDeclBuilder(Record)
507      .addHandleMember()
508      .addDefaultHandleConstructor(*SemaPtr, ResourceClass::UAV)
509      .addArraySubscriptOperators()
510      .annotateResourceClass(HLSLResourceAttr::UAV,
511                             HLSLResourceAttr::TypedBuffer)
512      .completeDefinition();
513}
514