1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef _LIBCPP___ALGORITHM_PSTL_BACKENDS_CPU_BACKENDS_TRANSFORM_H
10#define _LIBCPP___ALGORITHM_PSTL_BACKENDS_CPU_BACKENDS_TRANSFORM_H
11
12#include <__algorithm/pstl_backends/cpu_backends/backend.h>
13#include <__algorithm/transform.h>
14#include <__config>
15#include <__iterator/concepts.h>
16#include <__iterator/iterator_traits.h>
17#include <__type_traits/enable_if.h>
18#include <__type_traits/is_execution_policy.h>
19#include <__type_traits/remove_cvref.h>
20#include <optional>
21
22#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
23#  pragma GCC system_header
24#endif
25
26#if !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
27
28_LIBCPP_PUSH_MACROS
29#  include <__undef_macros>
30
31_LIBCPP_BEGIN_NAMESPACE_STD
32
// Invokes `__f(__first1[__i], __first2[__i])` for every `__i` in `[0, __n)` and returns `__first2 + __n`,
// i.e. the end of the second range.
//
// Both iterators are accessed with `operator[]` and `+`, so they must be random access.
// `_PSTL_PRAGMA_SIMD` is defined elsewhere (presumably a vectorization hint such as `#pragma omp simd`
// -- NOTE(review): confirm); keep the loop below in its plain counted form so the hint keeps applying.
// The function is `noexcept`: if `__f` throws, `std::terminate` is called.
template <class _Iterator1, class _DifferenceType, class _Iterator2, class _Function>
_LIBCPP_HIDE_FROM_ABI _Iterator2
__simd_walk(_Iterator1 __first1, _DifferenceType __n, _Iterator2 __first2, _Function __f) noexcept {
  _PSTL_PRAGMA_SIMD
  for (_DifferenceType __i = 0; __i < __n; ++__i)
    __f(__first1[__i], __first2[__i]);
  return __first2 + __n;
}
41
42template <class _ExecutionPolicy, class _ForwardIterator, class _ForwardOutIterator, class _UnaryOperation>
43_LIBCPP_HIDE_FROM_ABI optional<_ForwardOutIterator> __pstl_transform(
44    __cpu_backend_tag,
45    _ForwardIterator __first,
46    _ForwardIterator __last,
47    _ForwardOutIterator __result,
48    _UnaryOperation __op) {
49  if constexpr (__is_parallel_execution_policy_v<_ExecutionPolicy> &&
50                __has_random_access_iterator_category_or_concept<_ForwardIterator>::value &&
51                __has_random_access_iterator_category_or_concept<_ForwardOutIterator>::value) {
52    std::__par_backend::__parallel_for(
53        __first, __last, [__op, __first, __result](_ForwardIterator __brick_first, _ForwardIterator __brick_last) {
54          auto __res = std::__pstl_transform<__remove_parallel_policy_t<_ExecutionPolicy>>(
55              __cpu_backend_tag{}, __brick_first, __brick_last, __result + (__brick_first - __first), __op);
56          _LIBCPP_ASSERT_INTERNAL(__res, "unseq/seq should never try to allocate!");
57          return *std::move(__res);
58        });
59    return __result + (__last - __first);
60  } else if constexpr (__is_unsequenced_execution_policy_v<_ExecutionPolicy> &&
61                       __has_random_access_iterator_category_or_concept<_ForwardIterator>::value &&
62                       __has_random_access_iterator_category_or_concept<_ForwardOutIterator>::value) {
63    return std::__simd_walk(
64        __first,
65        __last - __first,
66        __result,
67        [&](__iter_reference<_ForwardIterator> __in_value, __iter_reference<_ForwardOutIterator> __out_value) {
68          __out_value = __op(__in_value);
69        });
70  } else {
71    return std::transform(__first, __last, __result, __op);
72  }
73}
74
// Three-range variant: invokes `__f(__first1[__i], __first2[__i], __first3[__i])` for every `__i` in
// `[0, __n)` and returns `__first3 + __n`, i.e. the end of the third range.
//
// All three iterators are accessed with `operator[]`, so they must be random access.
// `_PSTL_PRAGMA_SIMD` is defined elsewhere (presumably a vectorization hint such as `#pragma omp simd`
// -- NOTE(review): confirm); keep the loop below in its plain counted form so the hint keeps applying.
// The function is `noexcept`: if `__f` throws, `std::terminate` is called.
template <class _Iterator1, class _DifferenceType, class _Iterator2, class _Iterator3, class _Function>
_LIBCPP_HIDE_FROM_ABI _Iterator3 __simd_walk(
    _Iterator1 __first1, _DifferenceType __n, _Iterator2 __first2, _Iterator3 __first3, _Function __f) noexcept {
  _PSTL_PRAGMA_SIMD
  for (_DifferenceType __i = 0; __i < __n; ++__i)
    __f(__first1[__i], __first2[__i], __first3[__i]);
  return __first3 + __n;
}
83template <class _ExecutionPolicy,
84          class _ForwardIterator1,
85          class _ForwardIterator2,
86          class _ForwardOutIterator,
87          class _BinaryOperation,
88          enable_if_t<is_execution_policy_v<__remove_cvref_t<_ExecutionPolicy>>, int> = 0>
89_LIBCPP_HIDE_FROM_ABI optional<_ForwardOutIterator> __pstl_transform(
90    __cpu_backend_tag,
91    _ForwardIterator1 __first1,
92    _ForwardIterator1 __last1,
93    _ForwardIterator2 __first2,
94    _ForwardOutIterator __result,
95    _BinaryOperation __op) {
96  if constexpr (__is_parallel_execution_policy_v<_ExecutionPolicy> &&
97                __has_random_access_iterator_category_or_concept<_ForwardIterator1>::value &&
98                __has_random_access_iterator_category_or_concept<_ForwardIterator2>::value &&
99                __has_random_access_iterator_category_or_concept<_ForwardOutIterator>::value) {
100    auto __res = std::__par_backend::__parallel_for(
101        __first1,
102        __last1,
103        [__op, __first1, __first2, __result](_ForwardIterator1 __brick_first, _ForwardIterator1 __brick_last) {
104          return std::__pstl_transform<__remove_parallel_policy_t<_ExecutionPolicy>>(
105              __cpu_backend_tag{},
106              __brick_first,
107              __brick_last,
108              __first2 + (__brick_first - __first1),
109              __result + (__brick_first - __first1),
110              __op);
111        });
112    if (!__res)
113      return nullopt;
114    return __result + (__last1 - __first1);
115  } else if constexpr (__is_unsequenced_execution_policy_v<_ExecutionPolicy> &&
116                       __has_random_access_iterator_category_or_concept<_ForwardIterator1>::value &&
117                       __has_random_access_iterator_category_or_concept<_ForwardIterator2>::value &&
118                       __has_random_access_iterator_category_or_concept<_ForwardOutIterator>::value) {
119    return std::__simd_walk(
120        __first1,
121        __last1 - __first1,
122        __first2,
123        __result,
124        [&](__iter_reference<_ForwardIterator1> __in1,
125            __iter_reference<_ForwardIterator2> __in2,
126            __iter_reference<_ForwardOutIterator> __out_value) { __out_value = __op(__in1, __in2); });
127  } else {
128    return std::transform(__first1, __last1, __first2, __result, __op);
129  }
130}
131
132_LIBCPP_END_NAMESPACE_STD
133
134_LIBCPP_POP_MACROS
135
136#endif // !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
137
138#endif // _LIBCPP___ALGORITHM_PSTL_BACKENDS_CPU_BACKENDS_TRANSFORM_H
139