SemaCoroutine.cpp revision 296417
//===--- SemaCoroutine.cpp - Semantic Analysis for Coroutines -------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
//  This file implements semantic analysis for C++ Coroutines.
//
//===----------------------------------------------------------------------===//
13
14#include "clang/Sema/SemaInternal.h"
15#include "clang/AST/Decl.h"
16#include "clang/AST/ExprCXX.h"
17#include "clang/AST/StmtCXX.h"
18#include "clang/Lex/Preprocessor.h"
19#include "clang/Sema/Initialization.h"
20#include "clang/Sema/Overload.h"
21using namespace clang;
22using namespace sema;
23
24/// Look up the std::coroutine_traits<...>::promise_type for the given
25/// function type.
26static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
27                                  SourceLocation Loc) {
28  // FIXME: Cache std::coroutine_traits once we've found it.
29  NamespaceDecl *Std = S.getStdNamespace();
30  if (!Std) {
31    S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
32    return QualType();
33  }
34
35  LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"),
36                      Loc, Sema::LookupOrdinaryName);
37  if (!S.LookupQualifiedName(Result, Std)) {
38    S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found);
39    return QualType();
40  }
41
42  ClassTemplateDecl *CoroTraits = Result.getAsSingle<ClassTemplateDecl>();
43  if (!CoroTraits) {
44    Result.suppressDiagnostics();
45    // We found something weird. Complain about the first thing we found.
46    NamedDecl *Found = *Result.begin();
47    S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
48    return QualType();
49  }
50
51  // Form template argument list for coroutine_traits<R, P1, P2, ...>.
52  TemplateArgumentListInfo Args(Loc, Loc);
53  Args.addArgument(TemplateArgumentLoc(
54      TemplateArgument(FnType->getReturnType()),
55      S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc)));
56  // FIXME: If the function is a non-static member function, add the type
57  // of the implicit object parameter before the formal parameters.
58  for (QualType T : FnType->getParamTypes())
59    Args.addArgument(TemplateArgumentLoc(
60        TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc)));
61
62  // Build the template-id.
63  QualType CoroTrait =
64      S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args);
65  if (CoroTrait.isNull())
66    return QualType();
67  if (S.RequireCompleteType(Loc, CoroTrait,
68                            diag::err_coroutine_traits_missing_specialization))
69    return QualType();
70
71  CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl();
72  assert(RD && "specialization of class template is not a class?");
73
74  // Look up the ::promise_type member.
75  LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc,
76                 Sema::LookupOrdinaryName);
77  S.LookupQualifiedName(R, RD);
78  auto *Promise = R.getAsSingle<TypeDecl>();
79  if (!Promise) {
80    S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found)
81      << RD;
82    return QualType();
83  }
84
85  // The promise type is required to be a class type.
86  QualType PromiseType = S.Context.getTypeDeclType(Promise);
87  if (!PromiseType->getAsCXXRecordDecl()) {
88    // Use the fully-qualified name of the type.
89    auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, Std);
90    NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
91                                      CoroTrait.getTypePtr());
92    PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
93
94    S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class)
95      << PromiseType;
96    return QualType();
97  }
98
99  return PromiseType;
100}
101
102/// Check that this is a context in which a coroutine suspension can appear.
103static FunctionScopeInfo *
104checkCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) {
105  // 'co_await' and 'co_yield' are not permitted in unevaluated operands.
106  if (S.isUnevaluatedContext()) {
107    S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword;
108    return nullptr;
109  }
110
111  // Any other usage must be within a function.
112  // FIXME: Reject a coroutine with a deduced return type.
113  auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
114  if (!FD) {
115    S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
116                    ? diag::err_coroutine_objc_method
117                    : diag::err_coroutine_outside_function) << Keyword;
118  } else if (isa<CXXConstructorDecl>(FD) || isa<CXXDestructorDecl>(FD)) {
119    // Coroutines TS [special]/6:
120    //   A special member function shall not be a coroutine.
121    //
122    // FIXME: We assume that this really means that a coroutine cannot
123    //        be a constructor or destructor.
124    S.Diag(Loc, diag::err_coroutine_ctor_dtor)
125      << isa<CXXDestructorDecl>(FD) << Keyword;
126  } else if (FD->isConstexpr()) {
127    S.Diag(Loc, diag::err_coroutine_constexpr) << Keyword;
128  } else if (FD->isVariadic()) {
129    S.Diag(Loc, diag::err_coroutine_varargs) << Keyword;
130  } else {
131    auto *ScopeInfo = S.getCurFunction();
132    assert(ScopeInfo && "missing function scope for function");
133
134    // If we don't have a promise variable, build one now.
135    if (!ScopeInfo->CoroutinePromise) {
136      QualType T =
137          FD->getType()->isDependentType()
138              ? S.Context.DependentTy
139              : lookupPromiseType(S, FD->getType()->castAs<FunctionProtoType>(),
140                                  Loc);
141      if (T.isNull())
142        return nullptr;
143
144      // Create and default-initialize the promise.
145      ScopeInfo->CoroutinePromise =
146          VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(),
147                          &S.PP.getIdentifierTable().get("__promise"), T,
148                          S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
149      S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise);
150      if (!ScopeInfo->CoroutinePromise->isInvalidDecl())
151        S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise, false);
152    }
153
154    return ScopeInfo;
155  }
156
157  return nullptr;
158}
159
160/// Build a call to 'operator co_await' if there is a suitable operator for
161/// the given expression.
162static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
163                                           SourceLocation Loc, Expr *E) {
164  UnresolvedSet<16> Functions;
165  SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(),
166                                       Functions);
167  return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
168}
169
170struct ReadySuspendResumeResult {
171  bool IsInvalid;
172  Expr *Results[3];
173};
174
175static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
176                                  StringRef Name,
177                                  MutableArrayRef<Expr *> Args) {
178  DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
179
180  // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
181  CXXScopeSpec SS;
182  ExprResult Result = S.BuildMemberReferenceExpr(
183      Base, Base->getType(), Loc, /*IsPtr=*/false, SS,
184      SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr,
185      /*Scope=*/nullptr);
186  if (Result.isInvalid())
187    return ExprError();
188
189  return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
190}
191
192/// Build calls to await_ready, await_suspend, and await_resume for a co_await
193/// expression.
194static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc,
195                                                  Expr *E) {
196  // Assume invalid until we see otherwise.
197  ReadySuspendResumeResult Calls = {true, {}};
198
199  const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
200  for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) {
201    Expr *Operand = new (S.Context) OpaqueValueExpr(
202        Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
203
204    // FIXME: Pass coroutine handle to await_suspend.
205    ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], None);
206    if (Result.isInvalid())
207      return Calls;
208    Calls.Results[I] = Result.get();
209  }
210
211  Calls.IsInvalid = false;
212  return Calls;
213}
214
215ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
216  if (E->getType()->isPlaceholderType()) {
217    ExprResult R = CheckPlaceholderExpr(E);
218    if (R.isInvalid()) return ExprError();
219    E = R.get();
220  }
221
222  ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E);
223  if (Awaitable.isInvalid())
224    return ExprError();
225  return BuildCoawaitExpr(Loc, Awaitable.get());
226}
227ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) {
228  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
229  if (!Coroutine)
230    return ExprError();
231
232  if (E->getType()->isPlaceholderType()) {
233    ExprResult R = CheckPlaceholderExpr(E);
234    if (R.isInvalid()) return ExprError();
235    E = R.get();
236  }
237
238  if (E->getType()->isDependentType()) {
239    Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E);
240    Coroutine->CoroutineStmts.push_back(Res);
241    return Res;
242  }
243
244  // If the expression is a temporary, materialize it as an lvalue so that we
245  // can use it multiple times.
246  if (E->getValueKind() == VK_RValue)
247    E = new (Context) MaterializeTemporaryExpr(E->getType(), E, true);
248
249  // Build the await_ready, await_suspend, await_resume calls.
250  ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
251  if (RSS.IsInvalid)
252    return ExprError();
253
254  Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
255                                        RSS.Results[2]);
256  Coroutine->CoroutineStmts.push_back(Res);
257  return Res;
258}
259
260static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine,
261                                   SourceLocation Loc, StringRef Name,
262                                   MutableArrayRef<Expr *> Args) {
263  assert(Coroutine->CoroutinePromise && "no promise for coroutine");
264
265  // Form a reference to the promise.
266  auto *Promise = Coroutine->CoroutinePromise;
267  ExprResult PromiseRef = S.BuildDeclRefExpr(
268      Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
269  if (PromiseRef.isInvalid())
270    return ExprError();
271
272  // Call 'yield_value', passing in E.
273  return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
274}
275
276ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
277  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
278  if (!Coroutine)
279    return ExprError();
280
281  // Build yield_value call.
282  ExprResult Awaitable =
283      buildPromiseCall(*this, Coroutine, Loc, "yield_value", E);
284  if (Awaitable.isInvalid())
285    return ExprError();
286
287  // Build 'operator co_await' call.
288  Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get());
289  if (Awaitable.isInvalid())
290    return ExprError();
291
292  return BuildCoyieldExpr(Loc, Awaitable.get());
293}
294ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
295  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
296  if (!Coroutine)
297    return ExprError();
298
299  if (E->getType()->isPlaceholderType()) {
300    ExprResult R = CheckPlaceholderExpr(E);
301    if (R.isInvalid()) return ExprError();
302    E = R.get();
303  }
304
305  if (E->getType()->isDependentType()) {
306    Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
307    Coroutine->CoroutineStmts.push_back(Res);
308    return Res;
309  }
310
311  // If the expression is a temporary, materialize it as an lvalue so that we
312  // can use it multiple times.
313  if (E->getValueKind() == VK_RValue)
314    E = new (Context) MaterializeTemporaryExpr(E->getType(), E, true);
315
316  // Build the await_ready, await_suspend, await_resume calls.
317  ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E);
318  if (RSS.IsInvalid)
319    return ExprError();
320
321  Expr *Res = new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
322                                        RSS.Results[2]);
323  Coroutine->CoroutineStmts.push_back(Res);
324  return Res;
325}
326
327StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) {
328  return BuildCoreturnStmt(Loc, E);
329}
330StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) {
331  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
332  if (!Coroutine)
333    return StmtError();
334
335  if (E && E->getType()->isPlaceholderType() &&
336      !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) {
337    ExprResult R = CheckPlaceholderExpr(E);
338    if (R.isInvalid()) return StmtError();
339    E = R.get();
340  }
341
342  // FIXME: If the operand is a reference to a variable that's about to go out
343  // of scope, we should treat the operand as an xvalue for this overload
344  // resolution.
345  ExprResult PC;
346  if (E && !E->getType()->isVoidType()) {
347    PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E);
348  } else {
349    E = MakeFullDiscardedValueExpr(E).get();
350    PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None);
351  }
352  if (PC.isInvalid())
353    return StmtError();
354
355  Expr *PCE = ActOnFinishFullExpr(PC.get()).get();
356
357  Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE);
358  Coroutine->CoroutineStmts.push_back(Res);
359  return Res;
360}
361
362void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
363  FunctionScopeInfo *Fn = getCurFunction();
364  assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
365
366  // Coroutines [stmt.return]p1:
367  //   A return statement shall not appear in a coroutine.
368  if (Fn->FirstReturnLoc.isValid()) {
369    Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
370    auto *First = Fn->CoroutineStmts[0];
371    Diag(First->getLocStart(), diag::note_declared_coroutine_here)
372      << (isa<CoawaitExpr>(First) ? 0 :
373          isa<CoyieldExpr>(First) ? 1 : 2);
374  }
375
376  bool AnyCoawaits = false;
377  bool AnyCoyields = false;
378  for (auto *CoroutineStmt : Fn->CoroutineStmts) {
379    AnyCoawaits |= isa<CoawaitExpr>(CoroutineStmt);
380    AnyCoyields |= isa<CoyieldExpr>(CoroutineStmt);
381  }
382
383  if (!AnyCoawaits && !AnyCoyields)
384    Diag(Fn->CoroutineStmts.front()->getLocStart(),
385         diag::ext_coroutine_without_co_await_co_yield);
386
387  SourceLocation Loc = FD->getLocation();
388
389  // Form a declaration statement for the promise declaration, so that AST
390  // visitors can more easily find it.
391  StmtResult PromiseStmt =
392      ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc);
393  if (PromiseStmt.isInvalid())
394    return FD->setInvalidDecl();
395
396  // Form and check implicit 'co_await p.initial_suspend();' statement.
397  ExprResult InitialSuspend =
398      buildPromiseCall(*this, Fn, Loc, "initial_suspend", None);
399  // FIXME: Support operator co_await here.
400  if (!InitialSuspend.isInvalid())
401    InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get());
402  InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get());
403  if (InitialSuspend.isInvalid())
404    return FD->setInvalidDecl();
405
406  // Form and check implicit 'co_await p.final_suspend();' statement.
407  ExprResult FinalSuspend =
408      buildPromiseCall(*this, Fn, Loc, "final_suspend", None);
409  // FIXME: Support operator co_await here.
410  if (!FinalSuspend.isInvalid())
411    FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get());
412  FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get());
413  if (FinalSuspend.isInvalid())
414    return FD->setInvalidDecl();
415
416  // FIXME: Perform analysis of set_exception call.
417
418  // FIXME: Try to form 'p.return_void();' expression statement to handle
419  // control flowing off the end of the coroutine.
420
421  // Build implicit 'p.get_return_object()' expression and form initialization
422  // of return type from it.
423  ExprResult ReturnObject =
424    buildPromiseCall(*this, Fn, Loc, "get_return_object", None);
425  if (ReturnObject.isInvalid())
426    return FD->setInvalidDecl();
427  QualType RetType = FD->getReturnType();
428  if (!RetType->isDependentType()) {
429    InitializedEntity Entity =
430        InitializedEntity::InitializeResult(Loc, RetType, false);
431    ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType,
432                                                   ReturnObject.get());
433    if (ReturnObject.isInvalid())
434      return FD->setInvalidDecl();
435  }
436  ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc);
437  if (ReturnObject.isInvalid())
438    return FD->setInvalidDecl();
439
440  // FIXME: Perform move-initialization of parameters into frame-local copies.
441  SmallVector<Expr*, 16> ParamMoves;
442
443  // Build body for the coroutine wrapper statement.
444  Body = new (Context) CoroutineBodyStmt(
445      Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(),
446      /*SetException*/nullptr, /*Fallthrough*/nullptr,
447      ReturnObject.get(), ParamMoves);
448}
449