1//===--- HLSLExternalSemaSource.cpp - HLSL Sema Source --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//
10//===----------------------------------------------------------------------===//
11
12#include "clang/Sema/HLSLExternalSemaSource.h"
13#include "clang/AST/ASTContext.h"
14#include "clang/AST/Attr.h"
15#include "clang/AST/DeclCXX.h"
16#include "clang/Basic/AttrKinds.h"
17#include "clang/Basic/HLSLRuntime.h"
18#include "clang/Sema/Lookup.h"
19#include "clang/Sema/Sema.h"
20#include "llvm/Frontend/HLSL/HLSLResource.h"
21
22#include <functional>
23
24using namespace clang;
25using namespace llvm::hlsl;
26
27namespace {
28
29struct TemplateParameterListBuilder;
30
31struct BuiltinTypeDeclBuilder {
32  CXXRecordDecl *Record = nullptr;
33  ClassTemplateDecl *Template = nullptr;
34  ClassTemplateDecl *PrevTemplate = nullptr;
35  NamespaceDecl *HLSLNamespace = nullptr;
36  llvm::StringMap<FieldDecl *> Fields;
37
38  BuiltinTypeDeclBuilder(CXXRecordDecl *R) : Record(R) {
39    Record->startDefinition();
40    Template = Record->getDescribedClassTemplate();
41  }
42
43  BuiltinTypeDeclBuilder(Sema &S, NamespaceDecl *Namespace, StringRef Name)
44      : HLSLNamespace(Namespace) {
45    ASTContext &AST = S.getASTContext();
46    IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
47
48    LookupResult Result(S, &II, SourceLocation(), Sema::LookupTagName);
49    CXXRecordDecl *PrevDecl = nullptr;
50    if (S.LookupQualifiedName(Result, HLSLNamespace)) {
51      NamedDecl *Found = Result.getFoundDecl();
52      if (auto *TD = dyn_cast<ClassTemplateDecl>(Found)) {
53        PrevDecl = TD->getTemplatedDecl();
54        PrevTemplate = TD;
55      } else
56        PrevDecl = dyn_cast<CXXRecordDecl>(Found);
57      assert(PrevDecl && "Unexpected lookup result type.");
58    }
59
60    if (PrevDecl && PrevDecl->isCompleteDefinition()) {
61      Record = PrevDecl;
62      return;
63    }
64
65    Record = CXXRecordDecl::Create(AST, TagDecl::TagKind::Class, HLSLNamespace,
66                                   SourceLocation(), SourceLocation(), &II,
67                                   PrevDecl, true);
68    Record->setImplicit(true);
69    Record->setLexicalDeclContext(HLSLNamespace);
70    Record->setHasExternalLexicalStorage();
71
72    // Don't let anyone derive from built-in types.
73    Record->addAttr(FinalAttr::CreateImplicit(AST, SourceRange(),
74                                              FinalAttr::Keyword_final));
75  }
76
77  ~BuiltinTypeDeclBuilder() {
78    if (HLSLNamespace && !Template && Record->getDeclContext() == HLSLNamespace)
79      HLSLNamespace->addDecl(Record);
80  }
81
82  BuiltinTypeDeclBuilder &
83  addMemberVariable(StringRef Name, QualType Type,
84                    AccessSpecifier Access = AccessSpecifier::AS_private) {
85    if (Record->isCompleteDefinition())
86      return *this;
87    assert(Record->isBeingDefined() &&
88           "Definition must be started before adding members!");
89    ASTContext &AST = Record->getASTContext();
90
91    IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
92    TypeSourceInfo *MemTySource =
93        AST.getTrivialTypeSourceInfo(Type, SourceLocation());
94    auto *Field = FieldDecl::Create(
95        AST, Record, SourceLocation(), SourceLocation(), &II, Type, MemTySource,
96        nullptr, false, InClassInitStyle::ICIS_NoInit);
97    Field->setAccess(Access);
98    Field->setImplicit(true);
99    Record->addDecl(Field);
100    Fields[Name] = Field;
101    return *this;
102  }
103
104  BuiltinTypeDeclBuilder &
105  addHandleMember(AccessSpecifier Access = AccessSpecifier::AS_private) {
106    if (Record->isCompleteDefinition())
107      return *this;
108    QualType Ty = Record->getASTContext().VoidPtrTy;
109    if (Template) {
110      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
111              Template->getTemplateParameters()->getParam(0)))
112        Ty = Record->getASTContext().getPointerType(
113            QualType(TTD->getTypeForDecl(), 0));
114    }
115    return addMemberVariable("h", Ty, Access);
116  }
117
118  BuiltinTypeDeclBuilder &annotateResourceClass(ResourceClass RC,
119                                                ResourceKind RK, bool IsROV) {
120    if (Record->isCompleteDefinition())
121      return *this;
122    Record->addAttr(HLSLResourceAttr::CreateImplicit(Record->getASTContext(),
123                                                     RC, RK, IsROV));
124    return *this;
125  }
126
127  static DeclRefExpr *lookupBuiltinFunction(ASTContext &AST, Sema &S,
128                                            StringRef Name) {
129    CXXScopeSpec SS;
130    IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
131    DeclarationNameInfo NameInfo =
132        DeclarationNameInfo(DeclarationName(&II), SourceLocation());
133    LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
134    S.LookupParsedName(R, S.getCurScope(), &SS, false);
135    assert(R.isSingleResult() &&
136           "Since this is a builtin it should always resolve!");
137    auto *VD = cast<ValueDecl>(R.getFoundDecl());
138    QualType Ty = VD->getType();
139    return DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), SourceLocation(),
140                               VD, false, NameInfo, Ty, VK_PRValue);
141  }
142
143  static Expr *emitResourceClassExpr(ASTContext &AST, ResourceClass RC) {
144    return IntegerLiteral::Create(
145        AST,
146        llvm::APInt(AST.getIntWidth(AST.UnsignedCharTy),
147                    static_cast<uint8_t>(RC)),
148        AST.UnsignedCharTy, SourceLocation());
149  }
150
151  BuiltinTypeDeclBuilder &addDefaultHandleConstructor(Sema &S,
152                                                      ResourceClass RC) {
153    if (Record->isCompleteDefinition())
154      return *this;
155    ASTContext &AST = Record->getASTContext();
156
157    QualType ConstructorType =
158        AST.getFunctionType(AST.VoidTy, {}, FunctionProtoType::ExtProtoInfo());
159
160    CanQualType CanTy = Record->getTypeForDecl()->getCanonicalTypeUnqualified();
161    DeclarationName Name = AST.DeclarationNames.getCXXConstructorName(CanTy);
162    CXXConstructorDecl *Constructor = CXXConstructorDecl::Create(
163        AST, Record, SourceLocation(),
164        DeclarationNameInfo(Name, SourceLocation()), ConstructorType,
165        AST.getTrivialTypeSourceInfo(ConstructorType, SourceLocation()),
166        ExplicitSpecifier(), false, true, false,
167        ConstexprSpecKind::Unspecified);
168
169    DeclRefExpr *Fn =
170        lookupBuiltinFunction(AST, S, "__builtin_hlsl_create_handle");
171
172    Expr *RCExpr = emitResourceClassExpr(AST, RC);
173    Expr *Call = CallExpr::Create(AST, Fn, {RCExpr}, AST.VoidPtrTy, VK_PRValue,
174                                  SourceLocation(), FPOptionsOverride());
175
176    CXXThisExpr *This = CXXThisExpr::Create(
177        AST, SourceLocation(), Constructor->getFunctionObjectParameterType(),
178        true);
179    Expr *Handle = MemberExpr::CreateImplicit(AST, This, false, Fields["h"],
180                                              Fields["h"]->getType(), VK_LValue,
181                                              OK_Ordinary);
182
183    // If the handle isn't a void pointer, cast the builtin result to the
184    // correct type.
185    if (Handle->getType().getCanonicalType() != AST.VoidPtrTy) {
186      Call = CXXStaticCastExpr::Create(
187          AST, Handle->getType(), VK_PRValue, CK_Dependent, Call, nullptr,
188          AST.getTrivialTypeSourceInfo(Handle->getType(), SourceLocation()),
189          FPOptionsOverride(), SourceLocation(), SourceLocation(),
190          SourceRange());
191    }
192
193    BinaryOperator *Assign = BinaryOperator::Create(
194        AST, Handle, Call, BO_Assign, Handle->getType(), VK_LValue, OK_Ordinary,
195        SourceLocation(), FPOptionsOverride());
196
197    Constructor->setBody(
198        CompoundStmt::Create(AST, {Assign}, FPOptionsOverride(),
199                             SourceLocation(), SourceLocation()));
200    Constructor->setAccess(AccessSpecifier::AS_public);
201    Record->addDecl(Constructor);
202    return *this;
203  }
204
205  BuiltinTypeDeclBuilder &addArraySubscriptOperators() {
206    if (Record->isCompleteDefinition())
207      return *this;
208    addArraySubscriptOperator(true);
209    addArraySubscriptOperator(false);
210    return *this;
211  }
212
213  BuiltinTypeDeclBuilder &addArraySubscriptOperator(bool IsConst) {
214    if (Record->isCompleteDefinition())
215      return *this;
216    assert(Fields.count("h") > 0 &&
217           "Subscript operator must be added after the handle.");
218
219    FieldDecl *Handle = Fields["h"];
220    ASTContext &AST = Record->getASTContext();
221
222    assert(Handle->getType().getCanonicalType() != AST.VoidPtrTy &&
223           "Not yet supported for void pointer handles.");
224
225    QualType ElemTy =
226        QualType(Handle->getType()->getPointeeOrArrayElementType(), 0);
227    QualType ReturnTy = ElemTy;
228
229    FunctionProtoType::ExtProtoInfo ExtInfo;
230
231    // Subscript operators return references to elements, const makes the
232    // reference and method const so that the underlying data is not mutable.
233    ReturnTy = AST.getLValueReferenceType(ReturnTy);
234    if (IsConst) {
235      ExtInfo.TypeQuals.addConst();
236      ReturnTy.addConst();
237    }
238
239    QualType MethodTy =
240        AST.getFunctionType(ReturnTy, {AST.UnsignedIntTy}, ExtInfo);
241    auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
242    auto *MethodDecl = CXXMethodDecl::Create(
243        AST, Record, SourceLocation(),
244        DeclarationNameInfo(
245            AST.DeclarationNames.getCXXOperatorName(OO_Subscript),
246            SourceLocation()),
247        MethodTy, TSInfo, SC_None, false, false, ConstexprSpecKind::Unspecified,
248        SourceLocation());
249
250    IdentifierInfo &II = AST.Idents.get("Idx", tok::TokenKind::identifier);
251    auto *IdxParam = ParmVarDecl::Create(
252        AST, MethodDecl->getDeclContext(), SourceLocation(), SourceLocation(),
253        &II, AST.UnsignedIntTy,
254        AST.getTrivialTypeSourceInfo(AST.UnsignedIntTy, SourceLocation()),
255        SC_None, nullptr);
256    MethodDecl->setParams({IdxParam});
257
258    // Also add the parameter to the function prototype.
259    auto FnProtoLoc = TSInfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
260    FnProtoLoc.setParam(0, IdxParam);
261
262    auto *This =
263        CXXThisExpr::Create(AST, SourceLocation(),
264                            MethodDecl->getFunctionObjectParameterType(), true);
265    auto *HandleAccess = MemberExpr::CreateImplicit(
266        AST, This, false, Handle, Handle->getType(), VK_LValue, OK_Ordinary);
267
268    auto *IndexExpr = DeclRefExpr::Create(
269        AST, NestedNameSpecifierLoc(), SourceLocation(), IdxParam, false,
270        DeclarationNameInfo(IdxParam->getDeclName(), SourceLocation()),
271        AST.UnsignedIntTy, VK_PRValue);
272
273    auto *Array =
274        new (AST) ArraySubscriptExpr(HandleAccess, IndexExpr, ElemTy, VK_LValue,
275                                     OK_Ordinary, SourceLocation());
276
277    auto *Return = ReturnStmt::Create(AST, SourceLocation(), Array, nullptr);
278
279    MethodDecl->setBody(CompoundStmt::Create(AST, {Return}, FPOptionsOverride(),
280                                             SourceLocation(),
281                                             SourceLocation()));
282    MethodDecl->setLexicalDeclContext(Record);
283    MethodDecl->setAccess(AccessSpecifier::AS_public);
284    MethodDecl->addAttr(AlwaysInlineAttr::CreateImplicit(
285        AST, SourceRange(), AlwaysInlineAttr::CXX11_clang_always_inline));
286    Record->addDecl(MethodDecl);
287
288    return *this;
289  }
290
291  BuiltinTypeDeclBuilder &startDefinition() {
292    if (Record->isCompleteDefinition())
293      return *this;
294    Record->startDefinition();
295    return *this;
296  }
297
298  BuiltinTypeDeclBuilder &completeDefinition() {
299    if (Record->isCompleteDefinition())
300      return *this;
301    assert(Record->isBeingDefined() &&
302           "Definition must be started before completing it.");
303
304    Record->completeDefinition();
305    return *this;
306  }
307
308  TemplateParameterListBuilder addTemplateArgumentList();
309  BuiltinTypeDeclBuilder &addSimpleTemplateParams(ArrayRef<StringRef> Names);
310};
311
312struct TemplateParameterListBuilder {
313  BuiltinTypeDeclBuilder &Builder;
314  ASTContext &AST;
315  llvm::SmallVector<NamedDecl *> Params;
316
317  TemplateParameterListBuilder(BuiltinTypeDeclBuilder &RB)
318      : Builder(RB), AST(RB.Record->getASTContext()) {}
319
320  ~TemplateParameterListBuilder() { finalizeTemplateArgs(); }
321
322  TemplateParameterListBuilder &
323  addTypeParameter(StringRef Name, QualType DefaultValue = QualType()) {
324    if (Builder.Record->isCompleteDefinition())
325      return *this;
326    unsigned Position = static_cast<unsigned>(Params.size());
327    auto *Decl = TemplateTypeParmDecl::Create(
328        AST, Builder.Record->getDeclContext(), SourceLocation(),
329        SourceLocation(), /* TemplateDepth */ 0, Position,
330        &AST.Idents.get(Name, tok::TokenKind::identifier), /* Typename */ false,
331        /* ParameterPack */ false);
332    if (!DefaultValue.isNull())
333      Decl->setDefaultArgument(AST.getTrivialTypeSourceInfo(DefaultValue));
334
335    Params.emplace_back(Decl);
336    return *this;
337  }
338
339  BuiltinTypeDeclBuilder &finalizeTemplateArgs() {
340    if (Params.empty())
341      return Builder;
342    auto *ParamList =
343        TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
344                                      Params, SourceLocation(), nullptr);
345    Builder.Template = ClassTemplateDecl::Create(
346        AST, Builder.Record->getDeclContext(), SourceLocation(),
347        DeclarationName(Builder.Record->getIdentifier()), ParamList,
348        Builder.Record);
349    Builder.Record->setDescribedClassTemplate(Builder.Template);
350    Builder.Template->setImplicit(true);
351    Builder.Template->setLexicalDeclContext(Builder.Record->getDeclContext());
352    // NOTE: setPreviousDecl before addDecl so new decl replace old decl when
353    // make visible.
354    Builder.Template->setPreviousDecl(Builder.PrevTemplate);
355    Builder.Record->getDeclContext()->addDecl(Builder.Template);
356    Params.clear();
357
358    QualType T = Builder.Template->getInjectedClassNameSpecialization();
359    T = AST.getInjectedClassNameType(Builder.Record, T);
360
361    return Builder;
362  }
363};
364} // namespace
365
366TemplateParameterListBuilder BuiltinTypeDeclBuilder::addTemplateArgumentList() {
367  return TemplateParameterListBuilder(*this);
368}
369
370BuiltinTypeDeclBuilder &
371BuiltinTypeDeclBuilder::addSimpleTemplateParams(ArrayRef<StringRef> Names) {
372  TemplateParameterListBuilder Builder = this->addTemplateArgumentList();
373  for (StringRef Name : Names)
374    Builder.addTypeParameter(Name);
375  return Builder.finalizeTemplateArgs();
376}
377
378HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
379
380void HLSLExternalSemaSource::InitializeSema(Sema &S) {
381  SemaPtr = &S;
382  ASTContext &AST = SemaPtr->getASTContext();
383  // If the translation unit has external storage force external decls to load.
384  if (AST.getTranslationUnitDecl()->hasExternalLexicalStorage())
385    (void)AST.getTranslationUnitDecl()->decls_begin();
386
387  IdentifierInfo &HLSL = AST.Idents.get("hlsl", tok::TokenKind::identifier);
388  LookupResult Result(S, &HLSL, SourceLocation(), Sema::LookupNamespaceName);
389  NamespaceDecl *PrevDecl = nullptr;
390  if (S.LookupQualifiedName(Result, AST.getTranslationUnitDecl()))
391    PrevDecl = Result.getAsSingle<NamespaceDecl>();
392  HLSLNamespace = NamespaceDecl::Create(
393      AST, AST.getTranslationUnitDecl(), /*Inline=*/false, SourceLocation(),
394      SourceLocation(), &HLSL, PrevDecl, /*Nested=*/false);
395  HLSLNamespace->setImplicit(true);
396  HLSLNamespace->setHasExternalLexicalStorage();
397  AST.getTranslationUnitDecl()->addDecl(HLSLNamespace);
398
399  // Force external decls in the HLSL namespace to load from the PCH.
400  (void)HLSLNamespace->getCanonicalDecl()->decls_begin();
401  defineTrivialHLSLTypes();
402  defineHLSLTypesWithForwardDeclarations();
403
404  // This adds a `using namespace hlsl` directive. In DXC, we don't put HLSL's
405  // built in types inside a namespace, but we are planning to change that in
406  // the near future. In order to be source compatible older versions of HLSL
407  // will need to implicitly use the hlsl namespace. For now in clang everything
408  // will get added to the namespace, and we can remove the using directive for
409  // future language versions to match HLSL's evolution.
410  auto *UsingDecl = UsingDirectiveDecl::Create(
411      AST, AST.getTranslationUnitDecl(), SourceLocation(), SourceLocation(),
412      NestedNameSpecifierLoc(), SourceLocation(), HLSLNamespace,
413      AST.getTranslationUnitDecl());
414
415  AST.getTranslationUnitDecl()->addDecl(UsingDecl);
416}
417
418void HLSLExternalSemaSource::defineHLSLVectorAlias() {
419  ASTContext &AST = SemaPtr->getASTContext();
420
421  llvm::SmallVector<NamedDecl *> TemplateParams;
422
423  auto *TypeParam = TemplateTypeParmDecl::Create(
424      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
425      &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
426  TypeParam->setDefaultArgument(AST.getTrivialTypeSourceInfo(AST.FloatTy));
427
428  TemplateParams.emplace_back(TypeParam);
429
430  auto *SizeParam = NonTypeTemplateParmDecl::Create(
431      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
432      &AST.Idents.get("element_count", tok::TokenKind::identifier), AST.IntTy,
433      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
434  Expr *LiteralExpr =
435      IntegerLiteral::Create(AST, llvm::APInt(AST.getIntWidth(AST.IntTy), 4),
436                             AST.IntTy, SourceLocation());
437  SizeParam->setDefaultArgument(LiteralExpr);
438  TemplateParams.emplace_back(SizeParam);
439
440  auto *ParamList =
441      TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
442                                    TemplateParams, SourceLocation(), nullptr);
443
444  IdentifierInfo &II = AST.Idents.get("vector", tok::TokenKind::identifier);
445
446  QualType AliasType = AST.getDependentSizedExtVectorType(
447      AST.getTemplateTypeParmType(0, 0, false, TypeParam),
448      DeclRefExpr::Create(
449          AST, NestedNameSpecifierLoc(), SourceLocation(), SizeParam, false,
450          DeclarationNameInfo(SizeParam->getDeclName(), SourceLocation()),
451          AST.IntTy, VK_LValue),
452      SourceLocation());
453
454  auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
455                                       SourceLocation(), &II,
456                                       AST.getTrivialTypeSourceInfo(AliasType));
457  Record->setImplicit(true);
458
459  auto *Template =
460      TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
461                                    Record->getIdentifier(), ParamList, Record);
462
463  Record->setDescribedAliasTemplate(Template);
464  Template->setImplicit(true);
465  Template->setLexicalDeclContext(Record->getDeclContext());
466  HLSLNamespace->addDecl(Template);
467}
468
469void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
470  defineHLSLVectorAlias();
471
472  ResourceDecl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "Resource")
473                     .startDefinition()
474                     .addHandleMember(AccessSpecifier::AS_public)
475                     .completeDefinition()
476                     .Record;
477}
478
479/// Set up common members and attributes for buffer types
480static BuiltinTypeDeclBuilder setupBufferType(CXXRecordDecl *Decl, Sema &S,
481                                              ResourceClass RC, ResourceKind RK,
482                                              bool IsROV) {
483  return BuiltinTypeDeclBuilder(Decl)
484      .addHandleMember()
485      .addDefaultHandleConstructor(S, RC)
486      .annotateResourceClass(RC, RK, IsROV);
487}
488
489void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
490  CXXRecordDecl *Decl;
491  Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer")
492             .addSimpleTemplateParams({"element_type"})
493             .Record;
494  onCompletion(Decl, [this](CXXRecordDecl *Decl) {
495    setupBufferType(Decl, *SemaPtr, ResourceClass::UAV,
496                    ResourceKind::TypedBuffer, /*IsROV=*/false)
497        .addArraySubscriptOperators()
498        .completeDefinition();
499  });
500
501  Decl =
502      BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RasterizerOrderedBuffer")
503          .addSimpleTemplateParams({"element_type"})
504          .Record;
505  onCompletion(Decl, [this](CXXRecordDecl *Decl) {
506    setupBufferType(Decl, *SemaPtr, ResourceClass::UAV,
507                    ResourceKind::TypedBuffer, /*IsROV=*/true)
508        .addArraySubscriptOperators()
509        .completeDefinition();
510  });
511}
512
513void HLSLExternalSemaSource::onCompletion(CXXRecordDecl *Record,
514                                          CompletionFunction Fn) {
515  Completions.insert(std::make_pair(Record->getCanonicalDecl(), Fn));
516}
517
518void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
519  if (!isa<CXXRecordDecl>(Tag))
520    return;
521  auto Record = cast<CXXRecordDecl>(Tag);
522
523  // If this is a specialization, we need to get the underlying templated
524  // declaration and complete that.
525  if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(Record))
526    Record = TDecl->getSpecializedTemplate()->getTemplatedDecl();
527  Record = Record->getCanonicalDecl();
528  auto It = Completions.find(Record);
529  if (It == Completions.end())
530    return;
531  It->second(Record);
532}
533