1//===- ComparisonCategories.cpp - Three Way Comparison Data -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//  This file defines the Comparison Category enum and data types, which
10//  store the types and expressions needed to support operator<=>
11//
12//===----------------------------------------------------------------------===//
13
14#include "clang/AST/ComparisonCategories.h"
15#include "clang/AST/ASTContext.h"
16#include "clang/AST/Decl.h"
17#include "clang/AST/DeclCXX.h"
18#include "clang/AST/Type.h"
19#include "llvm/ADT/SmallVector.h"
20#include <optional>
21
22using namespace clang;
23
24std::optional<ComparisonCategoryType>
25clang::getComparisonCategoryForBuiltinCmp(QualType T) {
26  using CCT = ComparisonCategoryType;
27
28  if (T->isIntegralOrEnumerationType())
29    return CCT::StrongOrdering;
30
31  if (T->isRealFloatingType())
32    return CCT::PartialOrdering;
33
34  // C++2a [expr.spaceship]p8: If the composite pointer type is an object
35  // pointer type, p <=> q is of type std::strong_ordering.
36  // Note: this assumes neither operand is a null pointer constant.
37  if (T->isObjectPointerType())
38    return CCT::StrongOrdering;
39
40  // TODO: Extend support for operator<=> to ObjC types.
41  return std::nullopt;
42}
43
44bool ComparisonCategoryInfo::ValueInfo::hasValidIntValue() const {
45  assert(VD && "must have var decl");
46  if (!VD->isUsableInConstantExpressions(VD->getASTContext()))
47    return false;
48
49  // Before we attempt to get the value of the first field, ensure that we
50  // actually have one (and only one) field.
51  auto *Record = VD->getType()->getAsCXXRecordDecl();
52  if (std::distance(Record->field_begin(), Record->field_end()) != 1 ||
53      !Record->field_begin()->getType()->isIntegralOrEnumerationType())
54    return false;
55
56  return true;
57}
58
59/// Attempt to determine the integer value used to represent the comparison
60/// category result by evaluating the initializer for the specified VarDecl as
61/// a constant expression and retrieving the value of the class's first
62/// (and only) field.
63///
64/// Note: The STL types are expected to have the form:
65///    struct X { T value; };
66/// where T is an integral or enumeration type.
67llvm::APSInt ComparisonCategoryInfo::ValueInfo::getIntValue() const {
68  assert(hasValidIntValue() && "must have a valid value");
69  return VD->evaluateValue()->getStructField(0).getInt();
70}
71
72ComparisonCategoryInfo::ValueInfo *ComparisonCategoryInfo::lookupValueInfo(
73    ComparisonCategoryResult ValueKind) const {
74  // Check if we already have a cache entry for this value.
75  auto It = llvm::find_if(
76      Objects, [&](ValueInfo const &Info) { return Info.Kind == ValueKind; });
77  if (It != Objects.end())
78    return &(*It);
79
80  // We don't have a cached result. Lookup the variable declaration and create
81  // a new entry representing it.
82  DeclContextLookupResult Lookup = Record->getCanonicalDecl()->lookup(
83      &Ctx.Idents.get(ComparisonCategories::getResultString(ValueKind)));
84  if (Lookup.empty() || !isa<VarDecl>(Lookup.front()))
85    return nullptr;
86  Objects.emplace_back(ValueKind, cast<VarDecl>(Lookup.front()));
87  return &Objects.back();
88}
89
90static const NamespaceDecl *lookupStdNamespace(const ASTContext &Ctx,
91                                               NamespaceDecl *&StdNS) {
92  if (!StdNS) {
93    DeclContextLookupResult Lookup =
94        Ctx.getTranslationUnitDecl()->lookup(&Ctx.Idents.get("std"));
95    if (!Lookup.empty())
96      StdNS = dyn_cast<NamespaceDecl>(Lookup.front());
97  }
98  return StdNS;
99}
100
101static CXXRecordDecl *lookupCXXRecordDecl(const ASTContext &Ctx,
102                                          const NamespaceDecl *StdNS,
103                                          ComparisonCategoryType Kind) {
104  StringRef Name = ComparisonCategories::getCategoryString(Kind);
105  DeclContextLookupResult Lookup = StdNS->lookup(&Ctx.Idents.get(Name));
106  if (!Lookup.empty())
107    if (CXXRecordDecl *RD = dyn_cast<CXXRecordDecl>(Lookup.front()))
108      return RD;
109  return nullptr;
110}
111
112const ComparisonCategoryInfo *
113ComparisonCategories::lookupInfo(ComparisonCategoryType Kind) const {
114  auto It = Data.find(static_cast<char>(Kind));
115  if (It != Data.end())
116    return &It->second;
117
118  if (const NamespaceDecl *NS = lookupStdNamespace(Ctx, StdNS))
119    if (CXXRecordDecl *RD = lookupCXXRecordDecl(Ctx, NS, Kind))
120      return &Data.try_emplace((char)Kind, Ctx, RD, Kind).first->second;
121
122  return nullptr;
123}
124
125const ComparisonCategoryInfo *
126ComparisonCategories::lookupInfoForType(QualType Ty) const {
127  assert(!Ty.isNull() && "type must be non-null");
128  using CCT = ComparisonCategoryType;
129  auto *RD = Ty->getAsCXXRecordDecl();
130  if (!RD)
131    return nullptr;
132
133  // Check to see if we have information for the specified type cached.
134  const auto *CanonRD = RD->getCanonicalDecl();
135  for (auto &KV : Data) {
136    const ComparisonCategoryInfo &Info = KV.second;
137    if (CanonRD == Info.Record->getCanonicalDecl())
138      return &Info;
139  }
140
141  if (!RD->getEnclosingNamespaceContext()->isStdNamespace())
142    return nullptr;
143
144  // If not, check to see if the decl names a type in namespace std with a name
145  // matching one of the comparison category types.
146  for (unsigned I = static_cast<unsigned>(CCT::First),
147                End = static_cast<unsigned>(CCT::Last);
148       I <= End; ++I) {
149    CCT Kind = static_cast<CCT>(I);
150
151    // We've found the comparison category type. Build a new cache entry for
152    // it.
153    if (getCategoryString(Kind) == RD->getName())
154      return &Data.try_emplace((char)Kind, Ctx, RD, Kind).first->second;
155  }
156
157  // We've found nothing. This isn't a comparison category type.
158  return nullptr;
159}
160
161const ComparisonCategoryInfo &ComparisonCategories::getInfoForType(QualType Ty) const {
162  const ComparisonCategoryInfo *Info = lookupInfoForType(Ty);
163  assert(Info && "info for comparison category not found");
164  return *Info;
165}
166
167QualType ComparisonCategoryInfo::getType() const {
168  assert(Record);
169  return QualType(Record->getTypeForDecl(), 0);
170}
171
172StringRef ComparisonCategories::getCategoryString(ComparisonCategoryType Kind) {
173  using CCKT = ComparisonCategoryType;
174  switch (Kind) {
175  case CCKT::PartialOrdering:
176    return "partial_ordering";
177  case CCKT::WeakOrdering:
178    return "weak_ordering";
179  case CCKT::StrongOrdering:
180    return "strong_ordering";
181  }
182  llvm_unreachable("unhandled cases in switch");
183}
184
185StringRef ComparisonCategories::getResultString(ComparisonCategoryResult Kind) {
186  using CCVT = ComparisonCategoryResult;
187  switch (Kind) {
188  case CCVT::Equal:
189    return "equal";
190  case CCVT::Equivalent:
191    return "equivalent";
192  case CCVT::Less:
193    return "less";
194  case CCVT::Greater:
195    return "greater";
196  case CCVT::Unordered:
197    return "unordered";
198  }
199  llvm_unreachable("unhandled case in switch");
200}
201
202std::vector<ComparisonCategoryResult>
203ComparisonCategories::getPossibleResultsForType(ComparisonCategoryType Type) {
204  using CCT = ComparisonCategoryType;
205  using CCR = ComparisonCategoryResult;
206  std::vector<CCR> Values;
207  Values.reserve(4);
208  bool IsStrong = Type == CCT::StrongOrdering;
209  Values.push_back(IsStrong ? CCR::Equal : CCR::Equivalent);
210  Values.push_back(CCR::Less);
211  Values.push_back(CCR::Greater);
212  if (Type == CCT::PartialOrdering)
213    Values.push_back(CCR::Unordered);
214  return Values;
215}
216