1//===-- SemaCoroutine.cpp - Semantic Analysis for Coroutines --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//  This file implements semantic analysis for C++ Coroutines.
10//
11//  This file contains references to sections of the Coroutines TS, which
12//  can be found at http://wg21.link/coroutines.
13//
14//===----------------------------------------------------------------------===//
15
16#include "CoroutineStmtBuilder.h"
17#include "clang/AST/ASTLambda.h"
18#include "clang/AST/Decl.h"
19#include "clang/AST/ExprCXX.h"
20#include "clang/AST/StmtCXX.h"
21#include "clang/Basic/Builtins.h"
22#include "clang/Lex/Preprocessor.h"
23#include "clang/Sema/Initialization.h"
24#include "clang/Sema/Overload.h"
25#include "clang/Sema/ScopeInfo.h"
26#include "clang/Sema/SemaInternal.h"
27
28using namespace clang;
29using namespace sema;
30
31static LookupResult lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
32                                 SourceLocation Loc, bool &Res) {
33  DeclarationName DN = S.PP.getIdentifierInfo(Name);
34  LookupResult LR(S, DN, Loc, Sema::LookupMemberName);
35  // Suppress diagnostics when a private member is selected. The same warnings
36  // will be produced again when building the call.
37  LR.suppressDiagnostics();
38  Res = S.LookupQualifiedName(LR, RD);
39  return LR;
40}
41
42static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
43                         SourceLocation Loc) {
44  bool Res;
45  lookupMember(S, Name, RD, Loc, Res);
46  return Res;
47}
48
49/// Look up the std::coroutine_traits<...>::promise_type for the given
50/// function type.
51static QualType lookupPromiseType(Sema &S, const FunctionDecl *FD,
52                                  SourceLocation KwLoc) {
53  const FunctionProtoType *FnType = FD->getType()->castAs<FunctionProtoType>();
54  const SourceLocation FuncLoc = FD->getLocation();
55  // FIXME: Cache std::coroutine_traits once we've found it.
56  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
57  if (!StdExp) {
58    S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
59        << "std::experimental::coroutine_traits";
60    return QualType();
61  }
62
63  ClassTemplateDecl *CoroTraits = S.lookupCoroutineTraits(KwLoc, FuncLoc);
64  if (!CoroTraits) {
65    return QualType();
66  }
67
68  // Form template argument list for coroutine_traits<R, P1, P2, ...> according
69  // to [dcl.fct.def.coroutine]3
70  TemplateArgumentListInfo Args(KwLoc, KwLoc);
71  auto AddArg = [&](QualType T) {
72    Args.addArgument(TemplateArgumentLoc(
73        TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc)));
74  };
75  AddArg(FnType->getReturnType());
76  // If the function is a non-static member function, add the type
77  // of the implicit object parameter before the formal parameters.
78  if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) {
79    if (MD->isInstance()) {
80      // [over.match.funcs]4
81      // For non-static member functions, the type of the implicit object
82      // parameter is
83      //  -- "lvalue reference to cv X" for functions declared without a
84      //      ref-qualifier or with the & ref-qualifier
85      //  -- "rvalue reference to cv X" for functions declared with the &&
86      //      ref-qualifier
87      QualType T = MD->getThisType()->castAs<PointerType>()->getPointeeType();
88      T = FnType->getRefQualifier() == RQ_RValue
89              ? S.Context.getRValueReferenceType(T)
90              : S.Context.getLValueReferenceType(T, /*SpelledAsLValue*/ true);
91      AddArg(T);
92    }
93  }
94  for (QualType T : FnType->getParamTypes())
95    AddArg(T);
96
97  // Build the template-id.
98  QualType CoroTrait =
99      S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args);
100  if (CoroTrait.isNull())
101    return QualType();
102  if (S.RequireCompleteType(KwLoc, CoroTrait,
103                            diag::err_coroutine_type_missing_specialization))
104    return QualType();
105
106  auto *RD = CoroTrait->getAsCXXRecordDecl();
107  assert(RD && "specialization of class template is not a class?");
108
109  // Look up the ::promise_type member.
110  LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc,
111                 Sema::LookupOrdinaryName);
112  S.LookupQualifiedName(R, RD);
113  auto *Promise = R.getAsSingle<TypeDecl>();
114  if (!Promise) {
115    S.Diag(FuncLoc,
116           diag::err_implied_std_coroutine_traits_promise_type_not_found)
117        << RD;
118    return QualType();
119  }
120  // The promise type is required to be a class type.
121  QualType PromiseType = S.Context.getTypeDeclType(Promise);
122
123  auto buildElaboratedType = [&]() {
124    auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp);
125    NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
126                                      CoroTrait.getTypePtr());
127    return S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
128  };
129
130  if (!PromiseType->getAsCXXRecordDecl()) {
131    S.Diag(FuncLoc,
132           diag::err_implied_std_coroutine_traits_promise_type_not_class)
133        << buildElaboratedType();
134    return QualType();
135  }
136  if (S.RequireCompleteType(FuncLoc, buildElaboratedType(),
137                            diag::err_coroutine_promise_type_incomplete))
138    return QualType();
139
140  return PromiseType;
141}
142
143/// Look up the std::experimental::coroutine_handle<PromiseType>.
144static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
145                                          SourceLocation Loc) {
146  if (PromiseType.isNull())
147    return QualType();
148
149  NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
150  assert(StdExp && "Should already be diagnosed");
151
152  LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"),
153                      Loc, Sema::LookupOrdinaryName);
154  if (!S.LookupQualifiedName(Result, StdExp)) {
155    S.Diag(Loc, diag::err_implied_coroutine_type_not_found)
156        << "std::experimental::coroutine_handle";
157    return QualType();
158  }
159
160  ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>();
161  if (!CoroHandle) {
162    Result.suppressDiagnostics();
163    // We found something weird. Complain about the first thing we found.
164    NamedDecl *Found = *Result.begin();
165    S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle);
166    return QualType();
167  }
168
169  // Form template argument list for coroutine_handle<Promise>.
170  TemplateArgumentListInfo Args(Loc, Loc);
171  Args.addArgument(TemplateArgumentLoc(
172      TemplateArgument(PromiseType),
173      S.Context.getTrivialTypeSourceInfo(PromiseType, Loc)));
174
175  // Build the template-id.
176  QualType CoroHandleType =
177      S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args);
178  if (CoroHandleType.isNull())
179    return QualType();
180  if (S.RequireCompleteType(Loc, CoroHandleType,
181                            diag::err_coroutine_type_missing_specialization))
182    return QualType();
183
184  return CoroHandleType;
185}
186
187static bool isValidCoroutineContext(Sema &S, SourceLocation Loc,
188                                    StringRef Keyword) {
189  // [expr.await]p2 dictates that 'co_await' and 'co_yield' must be used within
190  // a function body.
191  // FIXME: This also covers [expr.await]p2: "An await-expression shall not
192  // appear in a default argument." But the diagnostic QoI here could be
193  // improved to inform the user that default arguments specifically are not
194  // allowed.
195  auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
196  if (!FD) {
197    S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
198                    ? diag::err_coroutine_objc_method
199                    : diag::err_coroutine_outside_function) << Keyword;
200    return false;
201  }
202
203  // An enumeration for mapping the diagnostic type to the correct diagnostic
204  // selection index.
205  enum InvalidFuncDiag {
206    DiagCtor = 0,
207    DiagDtor,
208    DiagMain,
209    DiagConstexpr,
210    DiagAutoRet,
211    DiagVarargs,
212    DiagConsteval,
213  };
214  bool Diagnosed = false;
215  auto DiagInvalid = [&](InvalidFuncDiag ID) {
216    S.Diag(Loc, diag::err_coroutine_invalid_func_context) << ID << Keyword;
217    Diagnosed = true;
218    return false;
219  };
220
221  // Diagnose when a constructor, destructor
222  // or the function 'main' are declared as a coroutine.
223  auto *MD = dyn_cast<CXXMethodDecl>(FD);
224  // [class.ctor]p11: "A constructor shall not be a coroutine."
225  if (MD && isa<CXXConstructorDecl>(MD))
226    return DiagInvalid(DiagCtor);
227  // [class.dtor]p17: "A destructor shall not be a coroutine."
228  else if (MD && isa<CXXDestructorDecl>(MD))
229    return DiagInvalid(DiagDtor);
230  // [basic.start.main]p3: "The function main shall not be a coroutine."
231  else if (FD->isMain())
232    return DiagInvalid(DiagMain);
233
234  // Emit a diagnostics for each of the following conditions which is not met.
235  // [expr.const]p2: "An expression e is a core constant expression unless the
236  // evaluation of e [...] would evaluate one of the following expressions:
237  // [...] an await-expression [...] a yield-expression."
238  if (FD->isConstexpr())
239    DiagInvalid(FD->isConsteval() ? DiagConsteval : DiagConstexpr);
240  // [dcl.spec.auto]p15: "A function declared with a return type that uses a
241  // placeholder type shall not be a coroutine."
242  if (FD->getReturnType()->isUndeducedType())
243    DiagInvalid(DiagAutoRet);
244  // [dcl.fct.def.coroutine]p1: "The parameter-declaration-clause of the
245  // coroutine shall not terminate with an ellipsis that is not part of a
246  // parameter-declaration."
247  if (FD->isVariadic())
248    DiagInvalid(DiagVarargs);
249
250  return !Diagnosed;
251}
252
253static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S,
254                                                 SourceLocation Loc) {
255  DeclarationName OpName =
256      SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
257  LookupResult Operators(SemaRef, OpName, SourceLocation(),
258                         Sema::LookupOperatorName);
259  SemaRef.LookupName(Operators, S);
260
261  assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
262  const auto &Functions = Operators.asUnresolvedSet();
263  bool IsOverloaded =
264      Functions.size() > 1 ||
265      (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin()));
266  Expr *CoawaitOp = UnresolvedLookupExpr::Create(
267      SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
268      DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded,
269      Functions.begin(), Functions.end());
270  assert(CoawaitOp);
271  return CoawaitOp;
272}
273
274/// Build a call to 'operator co_await' if there is a suitable operator for
275/// the given expression.
276static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
277                                           Expr *E,
278                                           UnresolvedLookupExpr *Lookup) {
279  UnresolvedSet<16> Functions;
280  Functions.append(Lookup->decls_begin(), Lookup->decls_end());
281  return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
282}
283
284static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
285                                           SourceLocation Loc, Expr *E) {
286  ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc);
287  if (R.isInvalid())
288    return ExprError();
289  return buildOperatorCoawaitCall(SemaRef, Loc, E,
290                                  cast<UnresolvedLookupExpr>(R.get()));
291}
292
293static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id,
294                              MultiExprArg CallArgs) {
295  StringRef Name = S.Context.BuiltinInfo.getName(Id);
296  LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName);
297  S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true);
298
299  auto *BuiltInDecl = R.getAsSingle<FunctionDecl>();
300  assert(BuiltInDecl && "failed to find builtin declaration");
301
302  ExprResult DeclRef =
303      S.BuildDeclRefExpr(BuiltInDecl, BuiltInDecl->getType(), VK_LValue, Loc);
304  assert(DeclRef.isUsable() && "Builtin reference cannot fail");
305
306  ExprResult Call =
307      S.BuildCallExpr(/*Scope=*/nullptr, DeclRef.get(), Loc, CallArgs, Loc);
308
309  assert(!Call.isInvalid() && "Call to builtin cannot fail!");
310  return Call.get();
311}
312
313static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType,
314                                       SourceLocation Loc) {
315  QualType CoroHandleType = lookupCoroutineHandleType(S, PromiseType, Loc);
316  if (CoroHandleType.isNull())
317    return ExprError();
318
319  DeclContext *LookupCtx = S.computeDeclContext(CoroHandleType);
320  LookupResult Found(S, &S.PP.getIdentifierTable().get("from_address"), Loc,
321                     Sema::LookupOrdinaryName);
322  if (!S.LookupQualifiedName(Found, LookupCtx)) {
323    S.Diag(Loc, diag::err_coroutine_handle_missing_member)
324        << "from_address";
325    return ExprError();
326  }
327
328  Expr *FramePtr =
329      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});
330
331  CXXScopeSpec SS;
332  ExprResult FromAddr =
333      S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
334  if (FromAddr.isInvalid())
335    return ExprError();
336
337  return S.BuildCallExpr(nullptr, FromAddr.get(), Loc, FramePtr, Loc);
338}
339
340struct ReadySuspendResumeResult {
341  enum AwaitCallType { ACT_Ready, ACT_Suspend, ACT_Resume };
342  Expr *Results[3];
343  OpaqueValueExpr *OpaqueValue;
344  bool IsInvalid;
345};
346
347static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
348                                  StringRef Name, MultiExprArg Args) {
349  DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
350
351  // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
352  CXXScopeSpec SS;
353  ExprResult Result = S.BuildMemberReferenceExpr(
354      Base, Base->getType(), Loc, /*IsPtr=*/false, SS,
355      SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr,
356      /*Scope=*/nullptr);
357  if (Result.isInvalid())
358    return ExprError();
359
360  // We meant exactly what we asked for. No need for typo correction.
361  if (auto *TE = dyn_cast<TypoExpr>(Result.get())) {
362    S.clearDelayedTypo(TE);
363    S.Diag(Loc, diag::err_no_member)
364        << NameInfo.getName() << Base->getType()->getAsCXXRecordDecl()
365        << Base->getSourceRange();
366    return ExprError();
367  }
368
369  return S.BuildCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
370}
371
372// See if return type is coroutine-handle and if so, invoke builtin coro-resume
373// on its address. This is to enable experimental support for coroutine-handle
374// returning await_suspend that results in a guaranteed tail call to the target
375// coroutine.
376static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
377                           SourceLocation Loc) {
378  if (RetType->isReferenceType())
379    return nullptr;
380  Type const *T = RetType.getTypePtr();
381  if (!T->isClassType() && !T->isStructureType())
382    return nullptr;
383
384  // FIXME: Add convertability check to coroutine_handle<>. Possibly via
385  // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment
386  // a private function in SemaExprCXX.cpp
387
388  ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", None);
389  if (AddressExpr.isInvalid())
390    return nullptr;
391
392  Expr *JustAddress = AddressExpr.get();
393  // FIXME: Check that the type of AddressExpr is void*
394  return buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_resume,
395                          JustAddress);
396}
397
398/// Build calls to await_ready, await_suspend, and await_resume for a co_await
399/// expression.
400static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise,
401                                                  SourceLocation Loc, Expr *E) {
402  OpaqueValueExpr *Operand = new (S.Context)
403      OpaqueValueExpr(Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
404
405  // Assume invalid until we see otherwise.
406  ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/true};
407
408  ExprResult CoroHandleRes = buildCoroutineHandle(S, CoroPromise->getType(), Loc);
409  if (CoroHandleRes.isInvalid())
410    return Calls;
411  Expr *CoroHandle = CoroHandleRes.get();
412
413  const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
414  MultiExprArg Args[] = {None, CoroHandle, None};
415  for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) {
416    ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args[I]);
417    if (Result.isInvalid())
418      return Calls;
419    Calls.Results[I] = Result.get();
420  }
421
422  // Assume the calls are valid; all further checking should make them invalid.
423  Calls.IsInvalid = false;
424
425  using ACT = ReadySuspendResumeResult::AwaitCallType;
426  CallExpr *AwaitReady = cast<CallExpr>(Calls.Results[ACT::ACT_Ready]);
427  if (!AwaitReady->getType()->isDependentType()) {
428    // [expr.await]p3 [...]
429    // ��� await-ready is the expression e.await_ready(), contextually converted
430    // to bool.
431    ExprResult Conv = S.PerformContextuallyConvertToBool(AwaitReady);
432    if (Conv.isInvalid()) {
433      S.Diag(AwaitReady->getDirectCallee()->getBeginLoc(),
434             diag::note_await_ready_no_bool_conversion);
435      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
436          << AwaitReady->getDirectCallee() << E->getSourceRange();
437      Calls.IsInvalid = true;
438    }
439    Calls.Results[ACT::ACT_Ready] = Conv.get();
440  }
441  CallExpr *AwaitSuspend = cast<CallExpr>(Calls.Results[ACT::ACT_Suspend]);
442  if (!AwaitSuspend->getType()->isDependentType()) {
443    // [expr.await]p3 [...]
444    //   - await-suspend is the expression e.await_suspend(h), which shall be
445    //     a prvalue of type void or bool.
446    QualType RetType = AwaitSuspend->getCallReturnType(S.Context);
447
448    // Experimental support for coroutine_handle returning await_suspend.
449    if (Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc))
450      Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
451    else {
452      // non-class prvalues always have cv-unqualified types
453      if (RetType->isReferenceType() ||
454          (!RetType->isBooleanType() && !RetType->isVoidType())) {
455        S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
456               diag::err_await_suspend_invalid_return_type)
457            << RetType;
458        S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
459            << AwaitSuspend->getDirectCallee();
460        Calls.IsInvalid = true;
461      }
462    }
463  }
464
465  return Calls;
466}
467
468static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise,
469                                   SourceLocation Loc, StringRef Name,
470                                   MultiExprArg Args) {
471
472  // Form a reference to the promise.
473  ExprResult PromiseRef = S.BuildDeclRefExpr(
474      Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
475  if (PromiseRef.isInvalid())
476    return ExprError();
477
478  return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
479}
480
481VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) {
482  assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
483  auto *FD = cast<FunctionDecl>(CurContext);
484  bool IsThisDependentType = [&] {
485    if (auto *MD = dyn_cast_or_null<CXXMethodDecl>(FD))
486      return MD->isInstance() && MD->getThisType()->isDependentType();
487    else
488      return false;
489  }();
490
491  QualType T = FD->getType()->isDependentType() || IsThisDependentType
492                   ? Context.DependentTy
493                   : lookupPromiseType(*this, FD, Loc);
494  if (T.isNull())
495    return nullptr;
496
497  auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(),
498                             &PP.getIdentifierTable().get("__promise"), T,
499                             Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
500  CheckVariableDeclarationType(VD);
501  if (VD->isInvalidDecl())
502    return nullptr;
503
504  auto *ScopeInfo = getCurFunction();
505  // Build a list of arguments, based on the coroutine functions arguments,
506  // that will be passed to the promise type's constructor.
507  llvm::SmallVector<Expr *, 4> CtorArgExprs;
508
509  // Add implicit object parameter.
510  if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) {
511    if (MD->isInstance() && !isLambdaCallOperator(MD)) {
512      ExprResult ThisExpr = ActOnCXXThis(Loc);
513      if (ThisExpr.isInvalid())
514        return nullptr;
515      ThisExpr = CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
516      if (ThisExpr.isInvalid())
517        return nullptr;
518      CtorArgExprs.push_back(ThisExpr.get());
519    }
520  }
521
522  auto &Moves = ScopeInfo->CoroutineParameterMoves;
523  for (auto *PD : FD->parameters()) {
524    if (PD->getType()->isDependentType())
525      continue;
526
527    auto RefExpr = ExprEmpty();
528    auto Move = Moves.find(PD);
529    assert(Move != Moves.end() &&
530           "Coroutine function parameter not inserted into move map");
531    // If a reference to the function parameter exists in the coroutine
532    // frame, use that reference.
533    auto *MoveDecl =
534        cast<VarDecl>(cast<DeclStmt>(Move->second)->getSingleDecl());
535    RefExpr =
536        BuildDeclRefExpr(MoveDecl, MoveDecl->getType().getNonReferenceType(),
537                         ExprValueKind::VK_LValue, FD->getLocation());
538    if (RefExpr.isInvalid())
539      return nullptr;
540    CtorArgExprs.push_back(RefExpr.get());
541  }
542
543  // Create an initialization sequence for the promise type using the
544  // constructor arguments, wrapped in a parenthesized list expression.
545  Expr *PLE = ParenListExpr::Create(Context, FD->getLocation(),
546                                    CtorArgExprs, FD->getLocation());
547  InitializedEntity Entity = InitializedEntity::InitializeVariable(VD);
548  InitializationKind Kind = InitializationKind::CreateForInit(
549      VD->getLocation(), /*DirectInit=*/true, PLE);
550  InitializationSequence InitSeq(*this, Entity, Kind, CtorArgExprs,
551                                 /*TopLevelOfInitList=*/false,
552                                 /*TreatUnavailableAsInvalid=*/false);
553
554  // Attempt to initialize the promise type with the arguments.
555  // If that fails, fall back to the promise type's default constructor.
556  if (InitSeq) {
557    ExprResult Result = InitSeq.Perform(*this, Entity, Kind, CtorArgExprs);
558    if (Result.isInvalid()) {
559      VD->setInvalidDecl();
560    } else if (Result.get()) {
561      VD->setInit(MaybeCreateExprWithCleanups(Result.get()));
562      VD->setInitStyle(VarDecl::CallInit);
563      CheckCompleteVariableDeclaration(VD);
564    }
565  } else
566    ActOnUninitializedDecl(VD);
567
568  FD->addDecl(VD);
569  return VD;
570}
571
572/// Check that this is a context in which a coroutine suspension can appear.
573static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
574                                                StringRef Keyword,
575                                                bool IsImplicit = false) {
576  if (!isValidCoroutineContext(S, Loc, Keyword))
577    return nullptr;
578
579  assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
580
581  auto *ScopeInfo = S.getCurFunction();
582  assert(ScopeInfo && "missing function scope for function");
583
584  if (ScopeInfo->FirstCoroutineStmtLoc.isInvalid() && !IsImplicit)
585    ScopeInfo->setFirstCoroutineStmt(Loc, Keyword);
586
587  if (ScopeInfo->CoroutinePromise)
588    return ScopeInfo;
589
590  if (!S.buildCoroutineParameterMoves(Loc))
591    return nullptr;
592
593  ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc);
594  if (!ScopeInfo->CoroutinePromise)
595    return nullptr;
596
597  return ScopeInfo;
598}
599
600bool Sema::ActOnCoroutineBodyStart(Scope *SC, SourceLocation KWLoc,
601                                   StringRef Keyword) {
602  if (!checkCoroutineContext(*this, KWLoc, Keyword))
603    return false;
604  auto *ScopeInfo = getCurFunction();
605  assert(ScopeInfo->CoroutinePromise);
606
607  // If we have existing coroutine statements then we have already built
608  // the initial and final suspend points.
609  if (!ScopeInfo->NeedsCoroutineSuspends)
610    return true;
611
612  ScopeInfo->setNeedsCoroutineSuspends(false);
613
614  auto *Fn = cast<FunctionDecl>(CurContext);
615  SourceLocation Loc = Fn->getLocation();
616  // Build the initial suspend point
617  auto buildSuspends = [&](StringRef Name) mutable -> StmtResult {
618    ExprResult Suspend =
619        buildPromiseCall(*this, ScopeInfo->CoroutinePromise, Loc, Name, None);
620    if (Suspend.isInvalid())
621      return StmtError();
622    Suspend = buildOperatorCoawaitCall(*this, SC, Loc, Suspend.get());
623    if (Suspend.isInvalid())
624      return StmtError();
625    Suspend = BuildResolvedCoawaitExpr(Loc, Suspend.get(),
626                                       /*IsImplicit*/ true);
627    Suspend = ActOnFinishFullExpr(Suspend.get(), /*DiscardedValue*/ false);
628    if (Suspend.isInvalid()) {
629      Diag(Loc, diag::note_coroutine_promise_suspend_implicitly_required)
630          << ((Name == "initial_suspend") ? 0 : 1);
631      Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword;
632      return StmtError();
633    }
634    return cast<Stmt>(Suspend.get());
635  };
636
637  StmtResult InitSuspend = buildSuspends("initial_suspend");
638  if (InitSuspend.isInvalid())
639    return true;
640
641  StmtResult FinalSuspend = buildSuspends("final_suspend");
642  if (FinalSuspend.isInvalid())
643    return true;
644
645  ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());
646
647  return true;
648}
649
650// Recursively walks up the scope hierarchy until either a 'catch' or a function
651// scope is found, whichever comes first.
652static bool isWithinCatchScope(Scope *S) {
653  // 'co_await' and 'co_yield' keywords are disallowed within catch blocks, but
654  // lambdas that use 'co_await' are allowed. The loop below ends when a
655  // function scope is found in order to ensure the following behavior:
656  //
657  // void foo() {      // <- function scope
658  //   try {           //
659  //     co_await x;   // <- 'co_await' is OK within a function scope
660  //   } catch {       // <- catch scope
661  //     co_await x;   // <- 'co_await' is not OK within a catch scope
662  //     []() {        // <- function scope
663  //       co_await x; // <- 'co_await' is OK within a function scope
664  //     }();
665  //   }
666  // }
667  while (S && !(S->getFlags() & Scope::FnScope)) {
668    if (S->getFlags() & Scope::CatchScope)
669      return true;
670    S = S->getParent();
671  }
672  return false;
673}
674
675// [expr.await]p2, emphasis added: "An await-expression shall appear only in
676// a *potentially evaluated* expression within the compound-statement of a
677// function-body *outside of a handler* [...] A context within a function
678// where an await-expression can appear is called a suspension context of the
679// function."
680static void checkSuspensionContext(Sema &S, SourceLocation Loc,
681                                   StringRef Keyword) {
682  // First emphasis of [expr.await]p2: must be a potentially evaluated context.
683  // That is, 'co_await' and 'co_yield' cannot appear in subexpressions of
684  // \c sizeof.
685  if (S.isUnevaluatedContext())
686    S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword;
687
688  // Second emphasis of [expr.await]p2: must be outside of an exception handler.
689  if (isWithinCatchScope(S.getCurScope()))
690    S.Diag(Loc, diag::err_coroutine_within_handler) << Keyword;
691}
692
693ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
694  if (!ActOnCoroutineBodyStart(S, Loc, "co_await")) {
695    CorrectDelayedTyposInExpr(E);
696    return ExprError();
697  }
698
699  checkSuspensionContext(*this, Loc, "co_await");
700
701  if (E->getType()->isPlaceholderType()) {
702    ExprResult R = CheckPlaceholderExpr(E);
703    if (R.isInvalid()) return ExprError();
704    E = R.get();
705  }
706  ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc);
707  if (Lookup.isInvalid())
708    return ExprError();
709  return BuildUnresolvedCoawaitExpr(Loc, E,
710                                   cast<UnresolvedLookupExpr>(Lookup.get()));
711}
712
713ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E,
714                                            UnresolvedLookupExpr *Lookup) {
715  auto *FSI = checkCoroutineContext(*this, Loc, "co_await");
716  if (!FSI)
717    return ExprError();
718
719  if (E->getType()->isPlaceholderType()) {
720    ExprResult R = CheckPlaceholderExpr(E);
721    if (R.isInvalid())
722      return ExprError();
723    E = R.get();
724  }
725
726  auto *Promise = FSI->CoroutinePromise;
727  if (Promise->getType()->isDependentType()) {
728    Expr *Res =
729        new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup);
730    return Res;
731  }
732
733  auto *RD = Promise->getType()->getAsCXXRecordDecl();
734  if (lookupMember(*this, "await_transform", RD, Loc)) {
735    ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E);
736    if (R.isInvalid()) {
737      Diag(Loc,
738           diag::note_coroutine_promise_implicit_await_transform_required_here)
739          << E->getSourceRange();
740      return ExprError();
741    }
742    E = R.get();
743  }
744  ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup);
745  if (Awaitable.isInvalid())
746    return ExprError();
747
748  return BuildResolvedCoawaitExpr(Loc, Awaitable.get());
749}
750
751ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E,
752                                  bool IsImplicit) {
753  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit);
754  if (!Coroutine)
755    return ExprError();
756
757  if (E->getType()->isPlaceholderType()) {
758    ExprResult R = CheckPlaceholderExpr(E);
759    if (R.isInvalid()) return ExprError();
760    E = R.get();
761  }
762
763  if (E->getType()->isDependentType()) {
764    Expr *Res = new (Context)
765        CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit);
766    return Res;
767  }
768
769  // If the expression is a temporary, materialize it as an lvalue so that we
770  // can use it multiple times.
771  if (E->getValueKind() == VK_RValue)
772    E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
773
774  // The location of the `co_await` token cannot be used when constructing
775  // the member call expressions since it's before the location of `Expr`, which
776  // is used as the start of the member call expression.
777  SourceLocation CallLoc = E->getExprLoc();
778
779  // Build the await_ready, await_suspend, await_resume calls.
780  ReadySuspendResumeResult RSS =
781      buildCoawaitCalls(*this, Coroutine->CoroutinePromise, CallLoc, E);
782  if (RSS.IsInvalid)
783    return ExprError();
784
785  Expr *Res =
786      new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
787                                RSS.Results[2], RSS.OpaqueValue, IsImplicit);
788
789  return Res;
790}
791
792ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
793  if (!ActOnCoroutineBodyStart(S, Loc, "co_yield")) {
794    CorrectDelayedTyposInExpr(E);
795    return ExprError();
796  }
797
798  checkSuspensionContext(*this, Loc, "co_yield");
799
800  // Build yield_value call.
801  ExprResult Awaitable = buildPromiseCall(
802      *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E);
803  if (Awaitable.isInvalid())
804    return ExprError();
805
806  // Build 'operator co_await' call.
807  Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get());
808  if (Awaitable.isInvalid())
809    return ExprError();
810
811  return BuildCoyieldExpr(Loc, Awaitable.get());
812}
813ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
814  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
815  if (!Coroutine)
816    return ExprError();
817
818  if (E->getType()->isPlaceholderType()) {
819    ExprResult R = CheckPlaceholderExpr(E);
820    if (R.isInvalid()) return ExprError();
821    E = R.get();
822  }
823
824  if (E->getType()->isDependentType()) {
825    Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
826    return Res;
827  }
828
829  // If the expression is a temporary, materialize it as an lvalue so that we
830  // can use it multiple times.
831  if (E->getValueKind() == VK_RValue)
832    E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
833
834  // Build the await_ready, await_suspend, await_resume calls.
835  ReadySuspendResumeResult RSS =
836      buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E);
837  if (RSS.IsInvalid)
838    return ExprError();
839
840  Expr *Res =
841      new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
842                                RSS.Results[2], RSS.OpaqueValue);
843
844  return Res;
845}
846
847StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) {
848  if (!ActOnCoroutineBodyStart(S, Loc, "co_return")) {
849    CorrectDelayedTyposInExpr(E);
850    return StmtError();
851  }
852  return BuildCoreturnStmt(Loc, E);
853}
854
/// Build a co_return statement for the operand \p E, which may be null.
///
/// The statement pairs the operand with a call on the coroutine promise:
/// return_value(E) when E is a braced-init-list or has a non-void type, and
/// return_void() otherwise. Returns StmtError() if the coroutine context
/// check, placeholder resolution, or the promise call fails.
StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E,
                                   bool IsImplicit) {
  auto *FSI = checkCoroutineContext(*this, Loc, "co_return", IsImplicit);
  if (!FSI)
    return StmtError();

  // Resolve placeholder-typed operands now, except for overload sets, which
  // are left unresolved at this point.
  if (E && E->getType()->isPlaceholderType() &&
      !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) {
    ExprResult R = CheckPlaceholderExpr(E);
    if (R.isInvalid()) return StmtError();
    E = R.get();
  }

  // Move the return value if we can
  if (E) {
    auto NRVOCandidate = this->getCopyElisionCandidate(E->getType(), E, CES_AsIfByStdMove);
    if (NRVOCandidate) {
      InitializedEntity Entity =
          InitializedEntity::InitializeResult(Loc, E->getType(), NRVOCandidate);
      ExprResult MoveResult = this->PerformMoveOrCopyInitialization(
          Entity, NRVOCandidate, E->getType(), E);
      // If the initialization produced no expression, the original operand is
      // kept unchanged.
      if (MoveResult.get())
        E = MoveResult.get();
    }
  }

  // FIXME: If the operand is a reference to a variable that's about to go out
  // of scope, we should treat the operand as an xvalue for this overload
  // resolution.
  VarDecl *Promise = FSI->CoroutinePromise;
  ExprResult PC;
  if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) {
    PC = buildPromiseCall(*this, Promise, Loc, "return_value", E);
  } else {
    // No operand, or a void-typed one: finish it as a discarded-value
    // expression and call return_void() with no arguments.
    E = MakeFullDiscardedValueExpr(E).get();
    PC = buildPromiseCall(*this, Promise, Loc, "return_void", None);
  }
  if (PC.isInvalid())
    return StmtError();

  Expr *PCE = ActOnFinishFullExpr(PC.get(), /*DiscardedValue*/ false).get();

  Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit);
  return Res;
}
900
901/// Look up the std::nothrow object.
902static Expr *buildStdNoThrowDeclRef(Sema &S, SourceLocation Loc) {
903  NamespaceDecl *Std = S.getStdNamespace();
904  assert(Std && "Should already be diagnosed");
905
906  LookupResult Result(S, &S.PP.getIdentifierTable().get("nothrow"), Loc,
907                      Sema::LookupOrdinaryName);
908  if (!S.LookupQualifiedName(Result, Std)) {
909    // FIXME: <experimental/coroutine> should have been included already.
910    // If we require it to include <new> then this diagnostic is no longer
911    // needed.
912    S.Diag(Loc, diag::err_implicit_coroutine_std_nothrow_type_not_found);
913    return nullptr;
914  }
915
916  auto *VD = Result.getAsSingle<VarDecl>();
917  if (!VD) {
918    Result.suppressDiagnostics();
919    // We found something weird. Complain about the first thing we found.
920    NamedDecl *Found = *Result.begin();
921    S.Diag(Found->getLocation(), diag::err_malformed_std_nothrow);
922    return nullptr;
923  }
924
925  ExprResult DR = S.BuildDeclRefExpr(VD, VD->getType(), VK_LValue, Loc);
926  if (DR.isInvalid())
927    return nullptr;
928
929  return DR.get();
930}
931
932// Find an appropriate delete for the promise.
933static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc,
934                                          QualType PromiseType) {
935  FunctionDecl *OperatorDelete = nullptr;
936
937  DeclarationName DeleteName =
938      S.Context.DeclarationNames.getCXXOperatorName(OO_Delete);
939
940  auto *PointeeRD = PromiseType->getAsCXXRecordDecl();
941  assert(PointeeRD && "PromiseType must be a CxxRecordDecl type");
942
943  if (S.FindDeallocationFunction(Loc, PointeeRD, DeleteName, OperatorDelete))
944    return nullptr;
945
946  if (!OperatorDelete) {
947    // Look for a global declaration.
948    const bool CanProvideSize = S.isCompleteType(Loc, PromiseType);
949    const bool Overaligned = false;
950    OperatorDelete = S.FindUsualDeallocationFunction(Loc, CanProvideSize,
951                                                     Overaligned, DeleteName);
952  }
953  S.MarkFunctionReferenced(Loc, OperatorDelete);
954  return OperatorDelete;
955}
956
957
958void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
959  FunctionScopeInfo *Fn = getCurFunction();
960  assert(Fn && Fn->isCoroutine() && "not a coroutine");
961  if (!Body) {
962    assert(FD->isInvalidDecl() &&
963           "a null body is only allowed for invalid declarations");
964    return;
965  }
966  // We have a function that uses coroutine keywords, but we failed to build
967  // the promise type.
968  if (!Fn->CoroutinePromise)
969    return FD->setInvalidDecl();
970
971  if (isa<CoroutineBodyStmt>(Body)) {
972    // Nothing todo. the body is already a transformed coroutine body statement.
973    return;
974  }
975
976  // Coroutines [stmt.return]p1:
977  //   A return statement shall not appear in a coroutine.
978  if (Fn->FirstReturnLoc.isValid()) {
979    assert(Fn->FirstCoroutineStmtLoc.isValid() &&
980                   "first coroutine location not set");
981    Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
982    Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
983            << Fn->getFirstCoroutineStmtKeyword();
984  }
985  CoroutineStmtBuilder Builder(*this, *FD, *Fn, Body);
986  if (Builder.isInvalid() || !Builder.buildStatements())
987    return FD->setInvalidDecl();
988
989  // Build body for the coroutine wrapper statement.
990  Body = CoroutineBodyStmt::Create(Context, Builder);
991}
992
993CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD,
994                                           sema::FunctionScopeInfo &Fn,
995                                           Stmt *Body)
996    : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
997      IsPromiseDependentType(
998          !Fn.CoroutinePromise ||
999          Fn.CoroutinePromise->getType()->isDependentType()) {
1000  this->Body = Body;
1001
1002  for (auto KV : Fn.CoroutineParameterMoves)
1003    this->ParamMovesVector.push_back(KV.second);
1004  this->ParamMoves = this->ParamMovesVector;
1005
1006  if (!IsPromiseDependentType) {
1007    PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
1008    assert(PromiseRecordDecl && "Type should have already been checked");
1009  }
1010  this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend();
1011}
1012
1013bool CoroutineStmtBuilder::buildStatements() {
1014  assert(this->IsValid && "coroutine already invalid");
1015  this->IsValid = makeReturnObject();
1016  if (this->IsValid && !IsPromiseDependentType)
1017    buildDependentStatements();
1018  return this->IsValid;
1019}
1020
1021bool CoroutineStmtBuilder::buildDependentStatements() {
1022  assert(this->IsValid && "coroutine already invalid");
1023  assert(!this->IsPromiseDependentType &&
1024         "coroutine cannot have a dependent promise type");
1025  this->IsValid = makeOnException() && makeOnFallthrough() &&
1026                  makeGroDeclAndReturnStmt() && makeReturnOnAllocFailure() &&
1027                  makeNewAndDeleteExpr();
1028  return this->IsValid;
1029}
1030
1031bool CoroutineStmtBuilder::makePromiseStmt() {
1032  // Form a declaration statement for the promise declaration, so that AST
1033  // visitors can more easily find it.
1034  StmtResult PromiseStmt =
1035      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
1036  if (PromiseStmt.isInvalid())
1037    return false;
1038
1039  this->Promise = PromiseStmt.get();
1040  return true;
1041}
1042
1043bool CoroutineStmtBuilder::makeInitialAndFinalSuspend() {
1044  if (Fn.hasInvalidCoroutineSuspends())
1045    return false;
1046  this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first);
1047  this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second);
1048  return true;
1049}
1050
1051static bool diagReturnOnAllocFailure(Sema &S, Expr *E,
1052                                     CXXRecordDecl *PromiseRecordDecl,
1053                                     FunctionScopeInfo &Fn) {
1054  auto Loc = E->getExprLoc();
1055  if (auto *DeclRef = dyn_cast_or_null<DeclRefExpr>(E)) {
1056    auto *Decl = DeclRef->getDecl();
1057    if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) {
1058      if (Method->isStatic())
1059        return true;
1060      else
1061        Loc = Decl->getLocation();
1062    }
1063  }
1064
1065  S.Diag(
1066      Loc,
1067      diag::err_coroutine_promise_get_return_object_on_allocation_failure)
1068      << PromiseRecordDecl;
1069  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1070      << Fn.getFirstCoroutineStmtKeyword();
1071  return false;
1072}
1073
1074bool CoroutineStmtBuilder::makeReturnOnAllocFailure() {
1075  assert(!IsPromiseDependentType &&
1076         "cannot make statement while the promise type is dependent");
1077
1078  // [dcl.fct.def.coroutine]/8
1079  // The unqualified-id get_return_object_on_allocation_failure is looked up in
1080  // the scope of class P by class member access lookup (3.4.5). ...
1081  // If an allocation function returns nullptr, ... the coroutine return value
1082  // is obtained by a call to ... get_return_object_on_allocation_failure().
1083
1084  DeclarationName DN =
1085      S.PP.getIdentifierInfo("get_return_object_on_allocation_failure");
1086  LookupResult Found(S, DN, Loc, Sema::LookupMemberName);
1087  if (!S.LookupQualifiedName(Found, PromiseRecordDecl))
1088    return true;
1089
1090  CXXScopeSpec SS;
1091  ExprResult DeclNameExpr =
1092      S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
1093  if (DeclNameExpr.isInvalid())
1094    return false;
1095
1096  if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn))
1097    return false;
1098
1099  ExprResult ReturnObjectOnAllocationFailure =
1100      S.BuildCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc);
1101  if (ReturnObjectOnAllocationFailure.isInvalid())
1102    return false;
1103
1104  StmtResult ReturnStmt =
1105      S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get());
1106  if (ReturnStmt.isInvalid()) {
1107    S.Diag(Found.getFoundDecl()->getLocation(), diag::note_member_declared_here)
1108        << DN;
1109    S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1110        << Fn.getFirstCoroutineStmtKeyword();
1111    return false;
1112  }
1113
1114  this->ReturnStmtOnAllocFailure = ReturnStmt.get();
1115  return true;
1116}
1117
/// Select the allocation and deallocation functions for the coroutine frame
/// and build the calls to them, storing the results in Allocate and
/// Deallocate.
///
/// operator new is looked up in the promise class first (with the coroutine's
/// arguments as placement arguments, then with the size alone) and then in the
/// global scope. When ReturnStmtOnAllocFailure has been built, a global
/// operator new is replaced by one taking std::nothrow, and the selected
/// function must be non-throwing. Returns false if any lookup or call
/// construction fails.
bool CoroutineStmtBuilder::makeNewAndDeleteExpr() {
  // Form and check allocation and deallocation calls.
  assert(!IsPromiseDependentType &&
         "cannot make statement while the promise type is dependent");
  QualType PromiseType = Fn.CoroutinePromise->getType();

  if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type))
    return false;

  // A non-throwing allocation is required exactly when the
  // return-on-allocation-failure statement exists.
  const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr;

  // [dcl.fct.def.coroutine]/7
  // Lookup allocation functions using a parameter list composed of the
  // requested size of the coroutine state being allocated, followed by
  // the coroutine function's arguments. If a matching allocation function
  // exists, use it. Otherwise, use an allocation function that just takes
  // the requested size.

  FunctionDecl *OperatorNew = nullptr;
  FunctionDecl *OperatorDelete = nullptr;
  FunctionDecl *UnusedResult = nullptr;
  bool PassAlignment = false;
  SmallVector<Expr *, 1> PlacementArgs;

  // [dcl.fct.def.coroutine]/7
  // "The allocation function's name is looked up in the scope of P.
  // [...] If the lookup finds an allocation function in the scope of P,
  // overload resolution is performed on a function call created by assembling
  // an argument list. The first argument is the amount of space requested,
  // and has type std::size_t. The lvalues p1 ... pn are the succeeding
  // arguments."
  //
  // ...where "p1 ... pn" are defined earlier as:
  //
  // [dcl.fct.def.coroutine]/3
  // "For a coroutine f that is a non-static member function, let P1 denote the
  // type of the implicit object parameter (13.3.1) and P2 ... Pn be the types
  // of the function parameters; otherwise let P1 ... Pn be the types of the
  // function parameters. Let p1 ... pn be lvalues denoting those objects."
  if (auto *MD = dyn_cast<CXXMethodDecl>(&FD)) {
    // Lambda call operators do not contribute an implicit object argument.
    if (MD->isInstance() && !isLambdaCallOperator(MD)) {
      ExprResult ThisExpr = S.ActOnCXXThis(Loc);
      if (ThisExpr.isInvalid())
        return false;
      ThisExpr = S.CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
      if (ThisExpr.isInvalid())
        return false;
      PlacementArgs.push_back(ThisExpr.get());
    }
  }
  for (auto *PD : FD.parameters()) {
    // Parameters whose type is still dependent are left out of the list.
    if (PD->getType()->isDependentType())
      continue;

    // Build a reference to the parameter.
    auto PDLoc = PD->getLocation();
    ExprResult PDRefExpr =
        S.BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
                           ExprValueKind::VK_LValue, PDLoc);
    if (PDRefExpr.isInvalid())
      return false;

    PlacementArgs.push_back(PDRefExpr.get());
  }
  // First attempt: class scope, size plus the placement arguments, without
  // diagnostics.
  S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class,
                            /*DeleteScope*/ Sema::AFS_Both, PromiseType,
                            /*isArray*/ false, PassAlignment, PlacementArgs,
                            OperatorNew, UnusedResult, /*Diagnose*/ false);

  // [dcl.fct.def.coroutine]/7
  // "If no matching function is found, overload resolution is performed again
  // on a function call created by passing just the amount of space required as
  // an argument of type std::size_t."
  if (!OperatorNew && !PlacementArgs.empty()) {
    PlacementArgs.clear();
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class,
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
                              /*isArray*/ false, PassAlignment, PlacementArgs,
                              OperatorNew, UnusedResult, /*Diagnose*/ false);
  }

  // [dcl.fct.def.coroutine]/7
  // "The allocation function's name is looked up in the scope of P. If this
  // lookup fails, the allocation function's name is looked up in the global
  // scope."
  if (!OperatorNew) {
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Global,
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
                              /*isArray*/ false, PassAlignment, PlacementArgs,
                              OperatorNew, UnusedResult);
  }

  bool IsGlobalOverload =
      OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext());
  // If we didn't find a class-local new declaration and non-throwing new
  // is required then we need to lookup the non-throwing global operator
  // instead.
  if (RequiresNoThrowAlloc && (!OperatorNew || IsGlobalOverload)) {
    auto *StdNoThrow = buildStdNoThrowDeclRef(S, Loc);
    if (!StdNoThrow)
      return false;
    PlacementArgs = {StdNoThrow};
    OperatorNew = nullptr;
    S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Both,
                              /*DeleteScope*/ Sema::AFS_Both, PromiseType,
                              /*isArray*/ false, PassAlignment, PlacementArgs,
                              OperatorNew, UnusedResult);
  }

  if (!OperatorNew)
    return false;

  // The selected operator new must itself be declared non-throwing.
  if (RequiresNoThrowAlloc) {
    const auto *FT = OperatorNew->getType()->castAs<FunctionProtoType>();
    if (!FT->isNothrow(/*ResultIfDependent*/ false)) {
      S.Diag(OperatorNew->getLocation(),
             diag::err_coroutine_promise_new_requires_nothrow)
          << OperatorNew;
      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
          << OperatorNew;
      return false;
    }
  }

  if ((OperatorDelete = findDeleteForPromise(S, Loc, PromiseType)) == nullptr)
    return false;

  Expr *FramePtr =
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});

  Expr *FrameSize =
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {});

  // Make new call.

  ExprResult NewRef =
      S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc);
  if (NewRef.isInvalid())
    return false;

  // The frame size comes first, followed by whichever placement arguments
  // were used to select OperatorNew.
  SmallVector<Expr *, 2> NewArgs(1, FrameSize);
  for (auto Arg : PlacementArgs)
    NewArgs.push_back(Arg);

  ExprResult NewExpr =
      S.BuildCallExpr(S.getCurScope(), NewRef.get(), Loc, NewArgs, Loc);
  NewExpr = S.ActOnFinishFullExpr(NewExpr.get(), /*DiscardedValue*/ false);
  if (NewExpr.isInvalid())
    return false;

  // Make delete call.

  QualType OpDeleteQualType = OperatorDelete->getType();

  ExprResult DeleteRef =
      S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc);
  if (DeleteRef.isInvalid())
    return false;

  Expr *CoroFree =
      buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr});

  SmallVector<Expr *, 2> DeleteArgs{CoroFree};

  // Check if we need to pass the size.
  const auto *OpDeleteType =
      OpDeleteQualType.getTypePtr()->castAs<FunctionProtoType>();
  if (OpDeleteType->getNumParams() > 1)
    DeleteArgs.push_back(FrameSize);

  ExprResult DeleteExpr =
      S.BuildCallExpr(S.getCurScope(), DeleteRef.get(), Loc, DeleteArgs, Loc);
  DeleteExpr =
      S.ActOnFinishFullExpr(DeleteExpr.get(), /*DiscardedValue*/ false);
  if (DeleteExpr.isInvalid())
    return false;

  this->Allocate = NewExpr.get();
  this->Deallocate = DeleteExpr.get();

  return true;
}
1300
1301bool CoroutineStmtBuilder::makeOnFallthrough() {
1302  assert(!IsPromiseDependentType &&
1303         "cannot make statement while the promise type is dependent");
1304
1305  // [dcl.fct.def.coroutine]/4
1306  // The unqualified-ids 'return_void' and 'return_value' are looked up in
1307  // the scope of class P. If both are found, the program is ill-formed.
1308  bool HasRVoid, HasRValue;
1309  LookupResult LRVoid =
1310      lookupMember(S, "return_void", PromiseRecordDecl, Loc, HasRVoid);
1311  LookupResult LRValue =
1312      lookupMember(S, "return_value", PromiseRecordDecl, Loc, HasRValue);
1313
1314  StmtResult Fallthrough;
1315  if (HasRVoid && HasRValue) {
1316    // FIXME Improve this diagnostic
1317    S.Diag(FD.getLocation(),
1318           diag::err_coroutine_promise_incompatible_return_functions)
1319        << PromiseRecordDecl;
1320    S.Diag(LRVoid.getRepresentativeDecl()->getLocation(),
1321           diag::note_member_first_declared_here)
1322        << LRVoid.getLookupName();
1323    S.Diag(LRValue.getRepresentativeDecl()->getLocation(),
1324           diag::note_member_first_declared_here)
1325        << LRValue.getLookupName();
1326    return false;
1327  } else if (!HasRVoid && !HasRValue) {
1328    // FIXME: The PDTS currently specifies this case as UB, not ill-formed.
1329    // However we still diagnose this as an error since until the PDTS is fixed.
1330    S.Diag(FD.getLocation(),
1331           diag::err_coroutine_promise_requires_return_function)
1332        << PromiseRecordDecl;
1333    S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
1334        << PromiseRecordDecl;
1335    return false;
1336  } else if (HasRVoid) {
1337    // If the unqualified-id return_void is found, flowing off the end of a
1338    // coroutine is equivalent to a co_return with no operand. Otherwise,
1339    // flowing off the end of a coroutine results in undefined behavior.
1340    Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr,
1341                                      /*IsImplicit*/false);
1342    Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
1343    if (Fallthrough.isInvalid())
1344      return false;
1345  }
1346
1347  this->OnFallthrough = Fallthrough.get();
1348  return true;
1349}
1350
1351bool CoroutineStmtBuilder::makeOnException() {
1352  // Try to form 'p.unhandled_exception();'
1353  assert(!IsPromiseDependentType &&
1354         "cannot make statement while the promise type is dependent");
1355
1356  const bool RequireUnhandledException = S.getLangOpts().CXXExceptions;
1357
1358  if (!lookupMember(S, "unhandled_exception", PromiseRecordDecl, Loc)) {
1359    auto DiagID =
1360        RequireUnhandledException
1361            ? diag::err_coroutine_promise_unhandled_exception_required
1362            : diag::
1363                  warn_coroutine_promise_unhandled_exception_required_with_exceptions;
1364    S.Diag(Loc, DiagID) << PromiseRecordDecl;
1365    S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
1366        << PromiseRecordDecl;
1367    return !RequireUnhandledException;
1368  }
1369
1370  // If exceptions are disabled, don't try to build OnException.
1371  if (!S.getLangOpts().CXXExceptions)
1372    return true;
1373
1374  ExprResult UnhandledException = buildPromiseCall(S, Fn.CoroutinePromise, Loc,
1375                                                   "unhandled_exception", None);
1376  UnhandledException = S.ActOnFinishFullExpr(UnhandledException.get(), Loc,
1377                                             /*DiscardedValue*/ false);
1378  if (UnhandledException.isInvalid())
1379    return false;
1380
1381  // Since the body of the coroutine will be wrapped in try-catch, it will
1382  // be incompatible with SEH __try if present in a function.
1383  if (!S.getLangOpts().Borland && Fn.FirstSEHTryLoc.isValid()) {
1384    S.Diag(Fn.FirstSEHTryLoc, diag::err_seh_in_a_coroutine_with_cxx_exceptions);
1385    S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1386        << Fn.getFirstCoroutineStmtKeyword();
1387    return false;
1388  }
1389
1390  this->OnException = UnhandledException.get();
1391  return true;
1392}
1393
1394bool CoroutineStmtBuilder::makeReturnObject() {
1395  // Build implicit 'p.get_return_object()' expression and form initialization
1396  // of return type from it.
1397  ExprResult ReturnObject =
1398      buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None);
1399  if (ReturnObject.isInvalid())
1400    return false;
1401
1402  this->ReturnValue = ReturnObject.get();
1403  return true;
1404}
1405
1406static void noteMemberDeclaredHere(Sema &S, Expr *E, FunctionScopeInfo &Fn) {
1407  if (auto *MbrRef = dyn_cast<CXXMemberCallExpr>(E)) {
1408    auto *MethodDecl = MbrRef->getMethodDecl();
1409    S.Diag(MethodDecl->getLocation(), diag::note_member_declared_here)
1410        << MethodDecl;
1411  }
1412  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1413      << Fn.getFirstCoroutineStmtKeyword();
1414}
1415
1416bool CoroutineStmtBuilder::makeGroDeclAndReturnStmt() {
1417  assert(!IsPromiseDependentType &&
1418         "cannot make statement while the promise type is dependent");
1419  assert(this->ReturnValue && "ReturnValue must be already formed");
1420
1421  QualType const GroType = this->ReturnValue->getType();
1422  assert(!GroType->isDependentType() &&
1423         "get_return_object type must no longer be dependent");
1424
1425  QualType const FnRetType = FD.getReturnType();
1426  assert(!FnRetType->isDependentType() &&
1427         "get_return_object type must no longer be dependent");
1428
1429  if (FnRetType->isVoidType()) {
1430    ExprResult Res =
1431        S.ActOnFinishFullExpr(this->ReturnValue, Loc, /*DiscardedValue*/ false);
1432    if (Res.isInvalid())
1433      return false;
1434
1435    this->ResultDecl = Res.get();
1436    return true;
1437  }
1438
1439  if (GroType->isVoidType()) {
1440    // Trigger a nice error message.
1441    InitializedEntity Entity =
1442        InitializedEntity::InitializeResult(Loc, FnRetType, false);
1443    S.PerformMoveOrCopyInitialization(Entity, nullptr, FnRetType, ReturnValue);
1444    noteMemberDeclaredHere(S, ReturnValue, Fn);
1445    return false;
1446  }
1447
1448  auto *GroDecl = VarDecl::Create(
1449      S.Context, &FD, FD.getLocation(), FD.getLocation(),
1450      &S.PP.getIdentifierTable().get("__coro_gro"), GroType,
1451      S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None);
1452
1453  S.CheckVariableDeclarationType(GroDecl);
1454  if (GroDecl->isInvalidDecl())
1455    return false;
1456
1457  InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl);
1458  ExprResult Res = S.PerformMoveOrCopyInitialization(Entity, nullptr, GroType,
1459                                                     this->ReturnValue);
1460  if (Res.isInvalid())
1461    return false;
1462
1463  Res = S.ActOnFinishFullExpr(Res.get(), /*DiscardedValue*/ false);
1464  if (Res.isInvalid())
1465    return false;
1466
1467  S.AddInitializerToDecl(GroDecl, Res.get(),
1468                         /*DirectInit=*/false);
1469
1470  S.FinalizeDeclaration(GroDecl);
1471
1472  // Form a declaration statement for the return declaration, so that AST
1473  // visitors can more easily find it.
1474  StmtResult GroDeclStmt =
1475      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc);
1476  if (GroDeclStmt.isInvalid())
1477    return false;
1478
1479  this->ResultDecl = GroDeclStmt.get();
1480
1481  ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc);
1482  if (declRef.isInvalid())
1483    return false;
1484
1485  StmtResult ReturnStmt = S.BuildReturnStmt(Loc, declRef.get());
1486  if (ReturnStmt.isInvalid()) {
1487    noteMemberDeclaredHere(S, ReturnValue, Fn);
1488    return false;
1489  }
1490  if (cast<clang::ReturnStmt>(ReturnStmt.get())->getNRVOCandidate() == GroDecl)
1491    GroDecl->setNRVOVariable(true);
1492
1493  this->ReturnStmt = ReturnStmt.get();
1494  return true;
1495}
1496
1497// Create a static_cast\<T&&>(expr).
1498static Expr *castForMoving(Sema &S, Expr *E, QualType T = QualType()) {
1499  if (T.isNull())
1500    T = E->getType();
1501  QualType TargetType = S.BuildReferenceType(
1502      T, /*SpelledAsLValue*/ false, SourceLocation(), DeclarationName());
1503  SourceLocation ExprLoc = E->getBeginLoc();
1504  TypeSourceInfo *TargetLoc =
1505      S.Context.getTrivialTypeSourceInfo(TargetType, ExprLoc);
1506
1507  return S
1508      .BuildCXXNamedCast(ExprLoc, tok::kw_static_cast, TargetLoc, E,
1509                         SourceRange(ExprLoc, ExprLoc), E->getSourceRange())
1510      .get();
1511}
1512
1513/// Build a variable declaration for move parameter.
1514static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type,
1515                             IdentifierInfo *II) {
1516  TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc);
1517  VarDecl *Decl = VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type,
1518                                  TInfo, SC_None);
1519  Decl->setImplicit();
1520  return Decl;
1521}
1522
1523// Build statements that move coroutine function parameters to the coroutine
1524// frame, and store them on the function scope info.
1525bool Sema::buildCoroutineParameterMoves(SourceLocation Loc) {
1526  assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
1527  auto *FD = cast<FunctionDecl>(CurContext);
1528
1529  auto *ScopeInfo = getCurFunction();
1530  if (!ScopeInfo->CoroutineParameterMoves.empty())
1531    return false;
1532
1533  for (auto *PD : FD->parameters()) {
1534    if (PD->getType()->isDependentType())
1535      continue;
1536
1537    ExprResult PDRefExpr =
1538        BuildDeclRefExpr(PD, PD->getType().getNonReferenceType(),
1539                         ExprValueKind::VK_LValue, Loc); // FIXME: scope?
1540    if (PDRefExpr.isInvalid())
1541      return false;
1542
1543    Expr *CExpr = nullptr;
1544    if (PD->getType()->getAsCXXRecordDecl() ||
1545        PD->getType()->isRValueReferenceType())
1546      CExpr = castForMoving(*this, PDRefExpr.get());
1547    else
1548      CExpr = PDRefExpr.get();
1549
1550    auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier());
1551    AddInitializerToDecl(D, CExpr, /*DirectInit=*/true);
1552
1553    // Convert decl to a statement.
1554    StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc);
1555    if (Stmt.isInvalid())
1556      return false;
1557
1558    ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get()));
1559  }
1560  return true;
1561}
1562
1563StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) {
1564  CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args);
1565  if (!Res)
1566    return StmtError();
1567  return Res;
1568}
1569
1570ClassTemplateDecl *Sema::lookupCoroutineTraits(SourceLocation KwLoc,
1571                                               SourceLocation FuncLoc) {
1572  if (!StdCoroutineTraitsCache) {
1573    if (auto StdExp = lookupStdExperimentalNamespace()) {
1574      LookupResult Result(*this,
1575                          &PP.getIdentifierTable().get("coroutine_traits"),
1576                          FuncLoc, LookupOrdinaryName);
1577      if (!LookupQualifiedName(Result, StdExp)) {
1578        Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
1579            << "std::experimental::coroutine_traits";
1580        return nullptr;
1581      }
1582      if (!(StdCoroutineTraitsCache =
1583                Result.getAsSingle<ClassTemplateDecl>())) {
1584        Result.suppressDiagnostics();
1585        NamedDecl *Found = *Result.begin();
1586        Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
1587        return nullptr;
1588      }
1589    }
1590  }
1591  return StdCoroutineTraitsCache;
1592}
1593