DDC 0.10.0
Loading...
Searching...
No Matches
transform_reduce.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <array>
8#include <cstddef>
9#include <utility>
10
11#include <Kokkos_Macros.hpp>
12
13#include "detail/macros.hpp"
14
15#include "discrete_vector.hpp"
16
17namespace ddc {
18
19namespace detail {
20
21/** A serial reduction over a nD domain
22 * @param[in] domain the range over which to apply the algorithm
23 * @param[in] neutral the neutral element of the reduction operation
24 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
25 * results of transform, the results of other reduce and neutral.
26 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
27 * range. The return type must be acceptable as input to reduce
28 * @param[in] is indices from dimensions already in a loop
29 */
30template <
31 class Support,
32 std::size_t N,
33 class T,
34 class BinaryReductionOp,
35 class UnaryTransformOp,
36 class... Is>
37T host_transform_reduce_serial(
38 Support const& domain,
39 std::array<DiscreteVectorElement, N> const& size,
40 [[maybe_unused]] T const neutral,
41 BinaryReductionOp const& reduce,
42 UnaryTransformOp const& transform,
43 Is const&... is) noexcept
44{
45 DDC_IF_NVCC_THEN_PUSH_AND_SUPPRESS(implicit_return_from_non_void_function)
46 static constexpr std::size_t I = sizeof...(Is);
47 if constexpr (I == N) {
48 return transform(domain(typename Support::discrete_vector_type(is...)));
49 } else {
50 T result = neutral;
51 for (DiscreteVectorElement ii = 0; ii < size[I]; ++ii) {
52 result = reduce(
53 host_transform_reduce_serial(
54 domain,
55 size,
56 neutral,
57 reduce,
58 transform,
59 is...,
60 ii),
61 result);
62 }
63 return result;
64 }
65 DDC_IF_NVCC_THEN_POP
66}
67
68/** A serial reduction over a nD domain. Can be called from a device kernel.
69 * @param[in] domain the range over which to apply the algorithm
70 * @param[in] neutral the neutral element of the reduction operation
71 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
72 * results of transform, the results of other reduce and neutral.
73 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
74 * range. The return type must be acceptable as input to reduce
75 * @param[in] is indices from dimensions already in a loop
76 */
77template <
78 class Support,
79 class T,
80 std::size_t N,
81 class BinaryReductionOp,
82 class UnaryTransformOp,
83 class... Is>
84KOKKOS_FUNCTION T device_transform_reduce_serial(
85 Support const& domain,
86 std::array<DiscreteVectorElement, N> const& size,
87 [[maybe_unused]] T const neutral,
88 BinaryReductionOp const& reduce,
89 UnaryTransformOp const& transform,
90 Is const&... is) noexcept
91{
92 DDC_IF_NVCC_THEN_PUSH_AND_SUPPRESS(implicit_return_from_non_void_function)
93 static constexpr std::size_t I = sizeof...(Is);
94 if constexpr (I == N) {
95 return transform(domain(typename Support::discrete_vector_type(is...)));
96 } else {
97 T result = neutral;
98 for (DiscreteVectorElement ii = 0; ii < size[I]; ++ii) {
99 result = reduce(
100 device_transform_reduce_serial(
101 domain,
102 size,
103 neutral,
104 reduce,
105 transform,
106 is...,
107 ii),
108 result);
109 }
110 return result;
111 }
112 DDC_IF_NVCC_THEN_POP
113}
114
115} // namespace detail
116
117#if defined(DDC_BUILD_DEPRECATED_CODE)
118/** A reduction over an n-dimensional domain in serial.
 *
 * @deprecated Use host_transform_reduce instead. This overload is only compiled when
 * DDC_BUILD_DEPRECATED_CODE is defined.
 *
119 * @param[in] domain the range over which to apply the algorithm
120 * @param[in] neutral the neutral element of the reduction operation
121 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
122 * results of transform, the results of other reduce and neutral.
123 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
124 * range. The return type must be acceptable as input to reduce
125 */
// NOTE(review): this rendered listing omits source lines 128, 131, 134, 137 and 138 (the
// declarator, the `reduce` parameter and part of the body); confirm against the real header.
126template <class Support, class T, class BinaryReductionOp, class UnaryTransformOp>
127[[deprecated("Use host_transform_reduce instead")]]
129 Support const& domain,
130 T neutral,
132 UnaryTransformOp&& transform) noexcept
133{
135 domain,
136 neutral,
139}
140#endif
141
142/** A reduction over a nD domain in serial
143 * @param[in] domain the range over which to apply the algorithm
144 * @param[in] neutral the neutral element of the reduction operation
145 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
146 * results of transform, the results of other reduce and neutral.
147 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
148 * range. The return type must be acceptable as input to reduce
149 */
150template <class Support, class T, class BinaryReductionOp, class UnaryTransformOp>
152 Support const& domain,
153 T neutral,
154 BinaryReductionOp&& reduce,
155 UnaryTransformOp&& transform) noexcept
156{
157 return detail::host_transform_reduce_serial(
158 domain,
159 detail::array(domain.extents()),
160 neutral,
161 std::forward<BinaryReductionOp>(reduce),
162 std::forward<UnaryTransformOp>(transform));
163}
164
165/** A reduction over a nD domain in serial. Can be called from a device kernel.
166 * @param[in] domain the range over which to apply the algorithm
167 * @param[in] neutral the neutral element of the reduction operation
168 * @param[in] reduce a binary FunctionObject that will be applied in unspecified order to the
169 * results of transform, the results of other reduce and neutral.
170 * @param[in] transform a unary FunctionObject that will be applied to each element of the input
171 * range. The return type must be acceptable as input to reduce
172 */
173template <class Support, class T, class BinaryReductionOp, class UnaryTransformOp>
175 Support const& domain,
176 T neutral,
177 BinaryReductionOp&& reduce,
178 UnaryTransformOp&& transform) noexcept
179{
180 return detail::device_transform_reduce_serial(
181 domain,
182 detail::array(domain.extents()),
183 neutral,
184 std::forward<BinaryReductionOp>(reduce),
185 std::forward<UnaryTransformOp>(transform));
186}
187
188} // namespace ddc
The top-level namespace of DDC.
T host_transform_reduce(Support const &domain, T neutral, BinaryReductionOp &&reduce, UnaryTransformOp &&transform) noexcept
A reduction over an n-dimensional domain, performed serially.
KOKKOS_FUNCTION T device_transform_reduce(Support const &domain, T neutral, BinaryReductionOp &&reduce, UnaryTransformOp &&transform) noexcept
A reduction over an n-dimensional domain, performed serially. Can be called from a device kernel.