// DDC 0.10.0 — transform_reduce.hpp
// Copyright (C) The DDC development team, see COPYRIGHT.md file
//
// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <array>
8#include <cstddef>
9#include <utility>
10
11#include <Kokkos_Macros.hpp>
12
13#include "detail/macros.hpp"
14
15#include "discrete_element.hpp"
16
17namespace ddc {
18
19namespace detail {
20
21/** A serial reduction over a nD domain
22 * @param[in] domain the range over which to apply the algorithm
23 * @param[in] neutral the neutral element of the reduction operation
24 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
25 * results of transform, the results of other reduce and neutral.
26 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
27 * range. The return type must be acceptable as input to reduce
28 * @param[in] dcoords discrete elements from dimensions already in a loop
29 */
30template <
31 class Support,
32 class Element,
33 std::size_t N,
34 class T,
35 class BinaryReductionOp,
36 class UnaryTransformOp,
37 class... Is>
38T host_transform_reduce_serial(
39 Support const& domain,
40 std::array<Element, N> const& size,
41 [[maybe_unused]] T const neutral,
42 BinaryReductionOp const& reduce,
43 UnaryTransformOp const& transform,
44 Is const&... is) noexcept
45{
46 DDC_IF_NVCC_THEN_PUSH_AND_SUPPRESS(implicit_return_from_non_void_function)
47 static constexpr std::size_t I = sizeof...(Is);
48 if constexpr (I == N) {
49 return transform(domain(typename Support::discrete_vector_type(is...)));
50 } else {
51 T result = neutral;
52 for (Element ii = 0; ii < size[I]; ++ii) {
53 result = reduce(
54 host_transform_reduce_serial(
55 domain,
56 size,
57 neutral,
58 reduce,
59 transform,
60 is...,
61 ii),
62 result);
63 }
64 return result;
65 }
66 DDC_IF_NVCC_THEN_POP
67}
68
69/** A serial reduction over a nD domain. Can be called from a device kernel.
70 * @param[in] domain the range over which to apply the algorithm
71 * @param[in] neutral the neutral element of the reduction operation
72 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
73 * results of transform, the results of other reduce and neutral.
74 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
75 * range. The return type must be acceptable as input to reduce
76 * @param[in] dcoords discrete elements from dimensions already in a loop
77 */
78template <
79 class Support,
80 class T,
81 class Element,
82 std::size_t N,
83 class BinaryReductionOp,
84 class UnaryTransformOp,
85 class... Is>
86KOKKOS_FUNCTION T device_transform_reduce_serial(
87 Support const& domain,
88 std::array<Element, N> const& size,
89 [[maybe_unused]] T const neutral,
90 BinaryReductionOp const& reduce,
91 UnaryTransformOp const& transform,
92 Is const&... is) noexcept
93{
94 DDC_IF_NVCC_THEN_PUSH_AND_SUPPRESS(implicit_return_from_non_void_function)
95 static constexpr std::size_t I = sizeof...(Is);
96 if constexpr (I == N) {
97 return transform(domain(typename Support::discrete_vector_type(is...)));
98 } else {
99 T result = neutral;
100 for (Element ii = 0; ii < size[I]; ++ii) {
101 result = reduce(
102 device_transform_reduce_serial(
103 domain,
104 size,
105 neutral,
106 reduce,
107 transform,
108 is...,
109 ii),
110 result);
111 }
112 return result;
113 }
114 DDC_IF_NVCC_THEN_POP
115}
116
117} // namespace detail
118
#if defined(DDC_BUILD_DEPRECATED_CODE)
/** A reduction over a nD domain in serial
 * @deprecated Use host_transform_reduce instead.
 * @param[in] domain the range over which to apply the algorithm
 * @param[in] neutral the neutral element of the reduction operation
 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
 * results of transform, the results of other reduce and neutral.
 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
 * range. The return type must be acceptable as input to reduce
 * @return the result of the reduction, of type T
 */
template <class Support, class T, class BinaryReductionOp, class UnaryTransformOp>
[[deprecated("Use host_transform_reduce instead")]]
T transform_reduce(
        Support const& domain,
        T neutral,
        BinaryReductionOp&& reduce,
        UnaryTransformOp&& transform) noexcept
{
    // Call the same serial host implementation that host_transform_reduce uses. Calling the
    // detail function directly means this template does not depend on host_transform_reduce,
    // which is only declared further down in this file.
    return detail::host_transform_reduce_serial(
            domain,
            detail::array(domain.extents()),
            neutral,
            std::forward<BinaryReductionOp>(reduce),
            std::forward<UnaryTransformOp>(transform));
}
#endif
143
144/** A reduction over a nD domain in serial
145 * @param[in] domain the range over which to apply the algorithm
146 * @param[in] neutral the neutral element of the reduction operation
147 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
148 * results of transform, the results of other reduce and neutral.
149 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
150 * range. The return type must be acceptable as input to reduce
151 */
152template <class Support, class T, class BinaryReductionOp, class UnaryTransformOp>
154 Support const& domain,
155 T neutral,
156 BinaryReductionOp&& reduce,
157 UnaryTransformOp&& transform) noexcept
158{
159 return detail::host_transform_reduce_serial(
160 domain,
161 detail::array(domain.extents()),
162 neutral,
163 std::forward<BinaryReductionOp>(reduce),
164 std::forward<UnaryTransformOp>(transform));
165}
166
167/** A reduction over a nD domain in serial. Can be called from a device kernel.
168 * @param[in] domain the range over which to apply the algorithm
169 * @param[in] neutral the neutral element of the reduction operation
170 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
171 * results of transform, the results of other reduce and neutral.
172 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
173 * range. The return type must be acceptable as input to reduce
174 */
175template <class Support, class T, class BinaryReductionOp, class UnaryTransformOp>
177 Support const& domain,
178 T neutral,
179 BinaryReductionOp&& reduce,
180 UnaryTransformOp&& transform) noexcept
181{
182 return detail::device_transform_reduce_serial(
183 domain,
184 detail::array(domain.extents()),
185 neutral,
186 std::forward<BinaryReductionOp>(reduce),
187 std::forward<UnaryTransformOp>(transform));
188}
189
190} // namespace ddc
// Reference index:
//   ddc: the top-level namespace of DDC.
//   T host_transform_reduce(Support const& domain, T neutral, BinaryReductionOp&& reduce,
//                           UnaryTransformOp&& transform) noexcept
//       A reduction over a nD domain in serial.
//   KOKKOS_FUNCTION T device_transform_reduce(Support const& domain, T neutral,
//                                             BinaryReductionOp&& reduce,
//                                             UnaryTransformOp&& transform) noexcept
//       A reduction over a nD domain in serial.