1//===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10///  This file implements the TypeSwitch template, which mimics a switch()
11///  statement whose cases are type names.
12///
//===----------------------------------------------------------------------===//
14
15#ifndef LLVM_ADT_TYPESWITCH_H
16#define LLVM_ADT_TYPESWITCH_H
17
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <optional>
#include <type_traits>
#include <utility>
21
22namespace llvm {
23namespace detail {
24
25template <typename DerivedT, typename T> class TypeSwitchBase {
26public:
27  TypeSwitchBase(const T &value) : value(value) {}
28  TypeSwitchBase(TypeSwitchBase &&other) : value(other.value) {}
29  ~TypeSwitchBase() = default;
30
31  /// TypeSwitchBase is not copyable.
32  TypeSwitchBase(const TypeSwitchBase &) = delete;
33  void operator=(const TypeSwitchBase &) = delete;
34  void operator=(TypeSwitchBase &&other) = delete;
35
36  /// Invoke a case on the derived class with multiple case types.
37  template <typename CaseT, typename CaseT2, typename... CaseTs,
38            typename CallableT>
39  // This is marked always_inline and nodebug so it doesn't show up in stack
40  // traces at -O0 (or other optimization levels).  Large TypeSwitch's are
41  // common, are equivalent to a switch, and don't add any value to stack
42  // traces.
43  LLVM_ATTRIBUTE_ALWAYS_INLINE LLVM_ATTRIBUTE_NODEBUG DerivedT &
44  Case(CallableT &&caseFn) {
45    DerivedT &derived = static_cast<DerivedT &>(*this);
46    return derived.template Case<CaseT>(caseFn)
47        .template Case<CaseT2, CaseTs...>(caseFn);
48  }
49
50  /// Invoke a case on the derived class, inferring the type of the Case from
51  /// the first input of the given callable.
52  /// Note: This inference rules for this overload are very simple: strip
53  ///       pointers and references.
54  template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
55    using Traits = function_traits<std::decay_t<CallableT>>;
56    using CaseT = std::remove_cv_t<std::remove_pointer_t<
57        std::remove_reference_t<typename Traits::template arg_t<0>>>>;
58
59    DerivedT &derived = static_cast<DerivedT &>(*this);
60    return derived.template Case<CaseT>(std::forward<CallableT>(caseFn));
61  }
62
63protected:
64  /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type
65  /// `CastT`.
66  template <typename ValueT, typename CastT>
67  using has_dyn_cast_t =
68      decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
69
70  /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
71  /// selected if `value` already has a suitable dyn_cast method.
72  template <typename CastT, typename ValueT>
73  static decltype(auto) castValue(
74      ValueT &&value,
75      std::enable_if_t<is_detected<has_dyn_cast_t, ValueT, CastT>::value> * =
76          nullptr) {
77    return value.template dyn_cast<CastT>();
78  }
79
80  /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
81  /// selected if llvm::dyn_cast should be used.
82  template <typename CastT, typename ValueT>
83  static decltype(auto) castValue(
84      ValueT &&value,
85      std::enable_if_t<!is_detected<has_dyn_cast_t, ValueT, CastT>::value> * =
86          nullptr) {
87    return dyn_cast<CastT>(value);
88  }
89
90  /// The root value we are switching on.
91  const T value;
92};
93} // end namespace detail
94
/// This class implements a switch-like dispatch statement for a value of 'T'
/// using dyn_cast functionality. Each `Case<T>` takes a callable to be invoked
/// if the root value isa<T>; the callable is invoked with the result of
/// dyn_cast<T>() as a parameter.
99///
100/// Example:
101///  Operation *op = ...;
102///  LogicalResult result = TypeSwitch<Operation *, LogicalResult>(op)
103///    .Case<ConstantOp>([](ConstantOp op) { ... })
104///    .Default([](Operation *op) { ... });
105///
106template <typename T, typename ResultT = void>
107class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
108public:
109  using BaseT = detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T>;
110  using BaseT::BaseT;
111  using BaseT::Case;
112  TypeSwitch(TypeSwitch &&other) = default;
113
114  /// Add a case on the given type.
115  template <typename CaseT, typename CallableT>
116  TypeSwitch<T, ResultT> &Case(CallableT &&caseFn) {
117    if (result)
118      return *this;
119
120    // Check to see if CaseT applies to 'value'.
121    if (auto caseValue = BaseT::template castValue<CaseT>(this->value))
122      result.emplace(caseFn(caseValue));
123    return *this;
124  }
125
126  /// As a default, invoke the given callable within the root value.
127  template <typename CallableT>
128  [[nodiscard]] ResultT Default(CallableT &&defaultFn) {
129    if (result)
130      return std::move(*result);
131    return defaultFn(this->value);
132  }
133  /// As a default, return the given value.
134  [[nodiscard]] ResultT Default(ResultT defaultResult) {
135    if (result)
136      return std::move(*result);
137    return defaultResult;
138  }
139
140  [[nodiscard]] operator ResultT() {
141    assert(result && "Fell off the end of a type-switch");
142    return std::move(*result);
143  }
144
145private:
146  /// The pointer to the result of this switch statement, once known,
147  /// null before that.
148  std::optional<ResultT> result;
149};
150
151/// Specialization of TypeSwitch for void returning callables.
152template <typename T>
153class TypeSwitch<T, void>
154    : public detail::TypeSwitchBase<TypeSwitch<T, void>, T> {
155public:
156  using BaseT = detail::TypeSwitchBase<TypeSwitch<T, void>, T>;
157  using BaseT::BaseT;
158  using BaseT::Case;
159  TypeSwitch(TypeSwitch &&other) = default;
160
161  /// Add a case on the given type.
162  template <typename CaseT, typename CallableT>
163  TypeSwitch<T, void> &Case(CallableT &&caseFn) {
164    if (foundMatch)
165      return *this;
166
167    // Check to see if any of the types apply to 'value'.
168    if (auto caseValue = BaseT::template castValue<CaseT>(this->value)) {
169      caseFn(caseValue);
170      foundMatch = true;
171    }
172    return *this;
173  }
174
175  /// As a default, invoke the given callable within the root value.
176  template <typename CallableT> void Default(CallableT &&defaultFn) {
177    if (!foundMatch)
178      defaultFn(this->value);
179  }
180
181private:
182  /// A flag detailing if we have already found a match.
183  bool foundMatch = false;
184};
185} // end namespace llvm
186
187#endif // LLVM_ADT_TYPESWITCH_H
188