DDC 0.10.0
Loading...
Searching...
No Matches
transform_reduce.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <utility>
8
9#include "detail/macros.hpp"
10
11#include "discrete_element.hpp"
12
13namespace ddc {
14
15namespace detail {
16
17/** A serial reduction over a nD domain
18 * @param[in] domain the range over which to apply the algorithm
19 * @param[in] neutral the neutral element of the reduction operation
20 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
21 * results of transform, the results of other reduce and neutral.
22 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
23 * range. The return type must be acceptable as input to reduce
24 * @param[in] dcoords discrete elements from dimensions already in a loop
25 */
26template <
27 class Support,
28 class Element,
29 std::size_t N,
30 class T,
31 class BinaryReductionOp,
32 class UnaryTransformOp,
33 class... Is>
34T host_transform_reduce_serial(
35 Support const& domain,
36 std::array<Element, N> const& size,
37 [[maybe_unused]] T const neutral,
38 BinaryReductionOp const& reduce,
39 UnaryTransformOp const& transform,
40 Is const&... is) noexcept
41{
42 DDC_IF_NVCC_THEN_PUSH_AND_SUPPRESS(implicit_return_from_non_void_function)
43 static constexpr std::size_t I = sizeof...(Is);
44 if constexpr (I == N) {
45 return transform(domain(typename Support::discrete_vector_type(is...)));
46 } else {
47 T result = neutral;
48 for (Element ii = 0; ii < size[I]; ++ii) {
49 result = reduce(
50 host_transform_reduce_serial(
51 domain,
52 size,
53 neutral,
54 reduce,
55 transform,
56 is...,
57 ii),
58 result);
59 }
60 return result;
61 }
62 DDC_IF_NVCC_THEN_POP
63}
64
65/** A serial reduction over a nD domain. Can be called from a device kernel.
66 * @param[in] domain the range over which to apply the algorithm
67 * @param[in] neutral the neutral element of the reduction operation
68 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
69 * results of transform, the results of other reduce and neutral.
70 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
71 * range. The return type must be acceptable as input to reduce
72 * @param[in] dcoords discrete elements from dimensions already in a loop
73 */
74template <
75 class Support,
76 class T,
77 class Element,
78 std::size_t N,
79 class BinaryReductionOp,
80 class UnaryTransformOp,
81 class... Is>
82KOKKOS_FUNCTION T device_transform_reduce_serial(
83 Support const& domain,
84 std::array<Element, N> const& size,
85 [[maybe_unused]] T const neutral,
86 BinaryReductionOp const& reduce,
87 UnaryTransformOp const& transform,
88 Is const&... is) noexcept
89{
90 DDC_IF_NVCC_THEN_PUSH_AND_SUPPRESS(implicit_return_from_non_void_function)
91 static constexpr std::size_t I = sizeof...(Is);
92 if constexpr (I == N) {
93 return transform(domain(typename Support::discrete_vector_type(is...)));
94 } else {
95 T result = neutral;
96 for (Element ii = 0; ii < size[I]; ++ii) {
97 result = reduce(
98 device_transform_reduce_serial(
99 domain,
100 size,
101 neutral,
102 reduce,
103 transform,
104 is...,
105 ii),
106 result);
107 }
108 return result;
109 }
110 DDC_IF_NVCC_THEN_POP
111}
112
113} // namespace detail
114
115#if defined(DDC_BUILD_DEPRECATED_CODE)
// NOTE(review): the declarator line, the `reduce` parameter and the forwarding call were absent
// from the listing this block was restored from. They are rebuilt from the parallel
// host_transform_reduce wrapper; the name `transform_reduce` is inferred from the header's file
// name and the deprecation message -- confirm against the upstream header.
/** A reduction over an nD domain in serial
 * @deprecated Use host_transform_reduce instead
 * @param[in] domain the range over which to apply the algorithm
 * @param[in] neutral the neutral element of the reduction operation
 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
 *            results of transform, the results of other reduce and neutral.
 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
 *            range. The return type must be acceptable as input to reduce
 * @return the value returned by host_transform_reduce for the same arguments
 */
template <class Support, class T, class BinaryReductionOp, class UnaryTransformOp>
[[deprecated("Use host_transform_reduce instead")]]
T transform_reduce(
        Support const& domain,
        T neutral,
        BinaryReductionOp&& reduce,
        UnaryTransformOp&& transform) noexcept
{
    return host_transform_reduce(
            domain,
            neutral,
            std::forward<BinaryReductionOp>(reduce),
            std::forward<UnaryTransformOp>(transform));
}
138#endif
139
140/** A reduction over a nD domain in serial
141 * @param[in] domain the range over which to apply the algorithm
142 * @param[in] neutral the neutral element of the reduction operation
143 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
144 * results of transform, the results of other reduce and neutral.
145 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
146 * range. The return type must be acceptable as input to reduce
147 */
148template <class Support, class T, class BinaryReductionOp, class UnaryTransformOp>
150 Support const& domain,
151 T neutral,
152 BinaryReductionOp&& reduce,
153 UnaryTransformOp&& transform) noexcept
154{
155 return detail::host_transform_reduce_serial(
156 domain,
157 detail::array(domain.extents()),
158 neutral,
159 std::forward<BinaryReductionOp>(reduce),
160 std::forward<UnaryTransformOp>(transform));
161}
162
163/** A reduction over a nD domain in serial. Can be called from a device kernel.
164 * @param[in] domain the range over which to apply the algorithm
165 * @param[in] neutral the neutral element of the reduction operation
166 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
167 * results of transform, the results of other reduce and neutral.
168 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
169 * range. The return type must be acceptable as input to reduce
170 */
171template <class Support, class T, class BinaryReductionOp, class UnaryTransformOp>
173 Support const& domain,
174 T neutral,
175 BinaryReductionOp&& reduce,
176 UnaryTransformOp&& transform) noexcept
177{
178 return detail::device_transform_reduce_serial(
179 domain,
180 detail::array(domain.extents()),
181 neutral,
182 std::forward<BinaryReductionOp>(reduce),
183 std::forward<UnaryTransformOp>(transform));
184}
185
186} // namespace ddc
The top-level namespace of DDC.
T host_transform_reduce(Support const &domain, T neutral, BinaryReductionOp &&reduce, UnaryTransformOp &&transform) noexcept
A reduction over a nD domain in serial.
KOKKOS_FUNCTION T device_transform_reduce(Support const &domain, T neutral, BinaryReductionOp &&reduce, UnaryTransformOp &&transform) noexcept
A reduction over a nD domain in serial.