DDC 0.14.0
Loading...
Searching...
No Matches
fft.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <cassert>
8#include <type_traits>
9#include <utility>
10
11#include <ddc/ddc.hpp>
12
13#include <KokkosFFT.hpp>
14#include <Kokkos_Core.hpp>
15
16namespace ddc {
17
/**
 * @brief A templated tag representing a continuous dimension in the Fourier space associated to the original continuous dimension.
 *
 * @tparam Dim The tag representing the original (continuous) dimension.
 */
template <typename Dim>
struct Fourier
{
};
27
/**
 * @brief A named argument to choose the direction of the FFT.
 *
 * @see KwArgsImpl, kwArgs_fft
 */
enum class FFT_Direction {
    FORWARD, ///< Forward, corresponds to direct FFT up to normalization
    BACKWARD ///< Backward, corresponds to inverse FFT up to normalization
};
37
/**
 * @brief A named argument to choose the type of normalization of the FFT.
 *
 * @see KwArgsImpl, kwArgs_fft
 */
enum class FFT_Normalization {
    OFF, ///< No normalization. Un-normalized FFT is sum_j f(x_j)*e^-ikx_j
    FORWARD, ///< Multiply by 1/N for forward FFT, no normalization for backward FFT
    BACKWARD, ///< No normalization for forward FFT, multiply by 1/N for backward FFT
    ORTHO, ///< Multiply by 1/sqrt(N)
    FULL /**<
          * Multiply by dx/sqrt(2*pi) for forward FFT and dk/sqrt(2*pi) for backward
          * FFT. It is aligned with the usual definition of the (continuous) Fourier transform
          * 1/sqrt(2*pi)*int f(x)*e^-ikx*dx, and thus may be relevant for spectral analysis applications.
          */
};
54
55} // namespace ddc
56
57namespace ddc::detail::fft {
58
59template <typename T>
60struct RealType
61{
62 using type = T;
63};
64
65template <typename T>
66struct RealType<Kokkos::complex<T>>
67{
68 using type = T;
69};
70
71template <typename T>
72using real_type_t = RealType<T>::type;
73
74// is_complex : trait to determine if type is Kokkos::complex<something>
75template <typename T>
76struct is_complex : std::false_type
77{
78};
79
80template <typename T>
81struct is_complex<Kokkos::complex<T>> : std::true_type
82{
83};
84
85template <typename T>
86constexpr bool is_complex_v = is_complex<T>::value;
87
88/*
89 * @brief A structure embedding the configuration of the impl FFT function: direction and type of normalization.
90 *
91 * @see FFT_impl
92 */
93struct KwArgsImpl
94{
96 direction; // Only effective for C2C transform and for normalization BACKWARD and FORWARD
97 ddc::FFT_Normalization normalization;
98};
99
100template <typename... DDimX>
101KokkosFFT::axis_type<sizeof...(DDimX)> axes()
102{
103 return KokkosFFT::axis_type<sizeof...(DDimX)> {
104 static_cast<int>(ddc::type_seq_rank_v<DDimX, ddc::detail::TypeSeq<DDimX...>>)...};
105}
106
107KokkosFFT::Normalization ddc_fft_normalization_to_kokkos_fft(
108 FFT_Normalization ddc_fft_normalization);
109
110template <class T>
111class ScaleFn
112{
113 T m_coef;
114
115public:
116 explicit ScaleFn(T coef) noexcept : m_coef(std::move(coef)) {}
117
118 template <class U>
119 [[nodiscard]] KOKKOS_FUNCTION U operator()(U const& value) const noexcept
120 {
121 return m_coef * value;
122 }
123};
124
125template <class DDim>
126Real forward_full_norm_coef(DiscreteDomain<DDim> const& ddom) noexcept
127{
128 return rlength(ddom) / Kokkos::sqrt(2 * Kokkos::numbers::pi_v<Real>)
129 / (ddom.extents() - 1).value();
130}
131
132template <class DDim>
133Real backward_full_norm_coef(DiscreteDomain<DDim> const& ddom) noexcept
134{
135 return 1 / (forward_full_norm_coef(ddom) * ddom.extents().value());
136}
137
138/// @brief Core internal function to perform the FFT.
139template <
140 typename Tin,
141 typename Tout,
142 typename ExecSpace,
143 typename MemorySpace,
144 typename LayoutIn,
145 typename LayoutOut,
146 typename... DDimIn,
147 typename... DDimOut>
148void impl(
149 ExecSpace const& exec_space,
150 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimIn...>, LayoutIn, MemorySpace> const& in,
151 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimOut...>, LayoutOut, MemorySpace> const& out,
152 KwArgsImpl const& kwargs)
153{
154 static_assert(
155 std::is_same_v<real_type_t<Tin>, float> || std::is_same_v<real_type_t<Tin>, double>,
156 "Base type of Tin (and Tout) must be float or double.");
157 static_assert(
158 std::is_same_v<real_type_t<Tin>, real_type_t<Tout>>,
159 "Types Tin and Tout must be based on same type (float or double)");
160 static_assert(
161 Kokkos::SpaceAccessibility<ExecSpace, MemorySpace>::accessible,
162 "MemorySpace has to be accessible for ExecutionSpace.");
163
164 Kokkos::View<
165 ddc::detail::mdspan_to_kokkos_element_t<Tin, sizeof...(DDimIn)>,
166 ddc::detail::mdspan_to_kokkos_layout_t<LayoutIn>,
167 MemorySpace> const in_view
168 = in.allocation_kokkos_view();
169 Kokkos::View<
170 ddc::detail::mdspan_to_kokkos_element_t<Tout, sizeof...(DDimIn)>,
171 ddc::detail::mdspan_to_kokkos_layout_t<LayoutOut>,
172 MemorySpace> const out_view
173 = out.allocation_kokkos_view();
174 KokkosFFT::Normalization const kokkos_fft_normalization
175 = ddc_fft_normalization_to_kokkos_fft(kwargs.normalization);
176
177 // C2C
178 if constexpr (std::is_same_v<Tin, Tout>) {
179 if (kwargs.direction == ddc::FFT_Direction::FORWARD) {
180 KokkosFFT::
181 fftn(exec_space,
182 in_view,
183 out_view,
184 axes<DDimIn...>(),
185 kokkos_fft_normalization);
186 } else {
187 KokkosFFT::
188 ifftn(exec_space,
189 in_view,
190 out_view,
191 axes<DDimIn...>(),
192 kokkos_fft_normalization);
193 }
194 // R2C & C2R
195 } else {
196 if constexpr (is_complex_v<Tout>) {
197 assert(kwargs.direction == ddc::FFT_Direction::FORWARD);
198 KokkosFFT::
199 rfftn(exec_space,
200 in_view,
201 out_view,
202 axes<DDimIn...>(),
203 kokkos_fft_normalization);
204 } else {
205 assert(kwargs.direction == ddc::FFT_Direction::BACKWARD);
206 KokkosFFT::
207 irfftn(exec_space,
208 in_view,
209 out_view,
210 axes<DDimIn...>(),
211 kokkos_fft_normalization);
212 }
213 }
214
215 // The FULL normalization is mesh-dependant and thus handled by DDC
216 if (kwargs.normalization == ddc::FFT_Normalization::FULL) {
217 Real norm_coef;
218 if (kwargs.direction == ddc::FFT_Direction::FORWARD) {
219 DiscreteDomain<DDimIn...> const ddom_in = in.domain();
220 norm_coef = (forward_full_norm_coef(DiscreteDomain<DDimIn>(ddom_in)) * ...);
221 } else {
222 DiscreteDomain<DDimOut...> const ddom_out = out.domain();
223 norm_coef = (backward_full_norm_coef(DiscreteDomain<DDimOut>(ddom_out)) * ...);
224 }
225
226 ddc::parallel_transform("ddc_fft", exec_space, out, ScaleFn<real_type_t<Tout>>(norm_coef));
227 }
228}
229
230} // namespace ddc::detail::fft
231
232namespace ddc {
233
234/**
235 * @brief Initialize a Fourier discrete dimension.
236 *
237 * Initialize the (1D) discrete space representing the Fourier discrete dimension associated
238 * to the (1D) mesh passed as argument. It is a N-periodic PeriodicSampling with a periodic window of width 2*pi/dx.
239 *
240 * This value comes from the Nyquist-Shannon theorem: the period of the spectral domain is N*dk = 2*pi/dx.
241 * Adding to this the relations dx = (xmax-xmin)/(N-1), and dk = (kmax-kmin)/(N-1), we get kmax-kmin = 2*pi*(N-1)^2/N/(xmax-xmin),
242 * which is used in the implementation (xmax, xmin, kmin and kmax are the centers of lower and upper cells inside a single period of the meshes).
243 *
244 * @tparam DDimFx A PeriodicSampling representing the Fourier discrete dimension.
245 * @tparam DDimX The type of the original discrete dimension.
246 *
247 * @param x_mesh The DiscreteDomain representing the (1D) original mesh.
248 *
249 * @return The initialized Impl representing the discrete Fourier space.
250 *
251 * @see PeriodicSampling
252 */
253template <typename DDimFx, typename DDimX>
254DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
255 ddc::DiscreteDomain<DDimX> x_mesh)
256{
257 static_assert(
258 is_uniform_point_sampling_v<DDimX>,
259 "DDimX dimension must derive from UniformPointSampling");
260 static_assert(
261 is_periodic_sampling_v<DDimFx>,
262 "DDimFx dimension must derive from PeriodicSampling");
263 using CDimFx = DDimFx::continuous_dimension_type;
264 using CDimX = DDimX::continuous_dimension_type;
265 static_assert(
266 std::is_same_v<CDimFx, ddc::Fourier<CDimX>>,
267 "DDimX and DDimFx dimensions must be defined over the same continuous dimension");
268
269 DiscreteVectorElement const nx = get<DDimX>(x_mesh.extents());
270 double const lx = ddc::rlength(x_mesh);
271 auto [impl, ddom] = DDimFx::template init<DDimFx>(
272 ddc::Coordinate<CDimFx>(0),
273 ddc::Coordinate<CDimFx>(2 * (nx - 1) * (nx - 1) / (nx * lx) * Kokkos::numbers::pi),
274 ddc::DiscreteVector<DDimFx>(nx),
275 ddc::DiscreteVector<DDimFx>(nx));
276 return std::move(impl);
277}
278
279/**
280 * @brief Get the Fourier mesh.
281 *
282 * Compute the Fourier (or spectral) mesh on which the Discrete Fourier Transform of a
283 * discrete function is defined.
284 *
285 * @param x_mesh The DiscreteDomain representing the original mesh.
286 * @param C2C A flag indicating if a complex-to-complex DFT is going to be performed. Indeed,
287 * in this case the two meshes have same number of points, whereas for real-to-complex
288 * or complex-to-real DFT, each complex value of the Fourier-transformed function contains twice more
289 * information, and thus only half (actually Nx*Ny*(Nz/2+1) for 3D R2C FFT to take in account mode 0)
290 * values are needed (cf. DFT conjugate symmetry property for more information about this).
291 *
292 * @return The domain representing the Fourier mesh.
293 */
294template <typename... DDimFx, typename... DDimX>
295ddc::DiscreteDomain<DDimFx...> fourier_mesh(ddc::DiscreteDomain<DDimX...> x_mesh, bool C2C)
296{
297 static_assert(
298 (is_uniform_point_sampling_v<DDimX> && ...),
299 "DDimX dimensions should derive from UniformPointSampling");
300 static_assert(
301 (is_periodic_sampling_v<DDimFx> && ...),
302 "DDimFx dimensions should derive from PeriodicPointSampling");
303 ddc::DiscreteVector<DDimX...> extents = x_mesh.extents();
304 if (!C2C) {
305 detail::array(extents).back() = detail::array(extents).back() / 2 + 1;
306 }
307 return ddc::DiscreteDomain<DDimFx...>(ddc::DiscreteDomain<DDimFx>(
308 ddc::DiscreteElement<DDimFx>(0),
309 ddc::DiscreteVector<DDimFx>(get<DDimX>(extents)))...);
310}
311
312/**
313 * @brief A structure embedding the configuration of the exposed FFT function with the type of normalization.
314 *
315 * @see fft, ifft
316 */
317struct kwArgs_fft
318{
320 normalization; ///< Enum member to identify the type of normalization performed
321};
322
323/**
324 * @brief Perform a direct Fast Fourier Transform.
325 *
326 * Compute the discrete Fourier transform of a function using the specialized implementation for the Kokkos::ExecutionSpace
327 * of the FFT algorithm.
328 *
329 * @tparam Tin The type of the input elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
330 * @tparam Tout The type of the output elements (Kokkos::complex<float> or Kokkos::complex<double>).
331 * @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
332 * @tparam DDimX... The parameter pack of the original discrete dimensions.
333 * @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the FFT is performed. It determines which specialized
334 * backend is used (ie. fftw, cuFFT...).
335 * @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
336 * @tparam LayoutIn The layout of the Chunkspan representing the input discrete function.
337 * @tparam LayoutOut The layout of the Chunkspan representing the output discrete function.
338 *
339 * @param exec_space The Kokkos::ExecutionSpace on which the FFT is performed.
340 * @param out The output discrete function, represented as a ChunkSpan storing values on a spectral mesh.
341 * @param in The input discrete function, represented as a ChunkSpan storing values on a mesh.
342 * @param kwargs The kwArgs_fft configuring the FFT.
343 */
344template <
345 typename Tin,
346 typename Tout,
347 typename... DDimFx,
348 typename... DDimX,
349 typename ExecSpace,
350 typename MemorySpace,
351 typename LayoutIn,
352 typename LayoutOut>
353void fft(
354 ExecSpace const& exec_space,
355 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimFx...>, LayoutOut, MemorySpace> out,
356 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimX...>, LayoutIn, MemorySpace> in,
358{
359 static_assert(
360 std::is_same_v<LayoutIn, Kokkos::layout_right>
361 && std::is_same_v<LayoutOut, Kokkos::layout_right>,
362 "Layouts must be right-handed");
363 static_assert(
364 (is_uniform_point_sampling_v<DDimX> && ...),
365 "DDimX dimensions should derive from UniformPointSampling");
366 static_assert(
367 (is_periodic_sampling_v<DDimFx> && ...),
368 "DDimFx dimensions should derive from PeriodicPointSampling");
369
370 ddc::detail::fft::
371 impl(exec_space, in, out, {ddc::FFT_Direction::FORWARD, kwargs.normalization});
372}
373
374/**
375 * @brief Perform an inverse Fast Fourier Transform.
376 *
377 * Compute the inverse discrete Fourier transform of a spectral function using the specialized implementation for the Kokkos::ExecutionSpace
378 * of the iFFT algorithm.
379 *
380 * @warning C2R iFFT does NOT preserve input.
381 *
382 * @tparam Tin The type of the input elements (Kokkos::complex<float> or Kokkos::complex<double>).
383 * @tparam Tout The type of the output elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
384 * @tparam DDimX... The parameter pack of the original discrete dimensions.
385 * @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
386 * @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the iFFT is performed. It determines which specialized
387 * backend is used (ie. fftw, cuFFT...).
388 * @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
389 * @tparam LayoutIn The layout of the Chunkspan representing the input discrete function.
390 * @tparam LayoutOut The layout of the Chunkspan representing the output discrete function.
391 *
392 * @param exec_space The Kokkos::ExecutionSpace on which the iFFT is performed.
393 * @param out The output discrete function, represented as a ChunkSpan storing values on a mesh.
394 * @param in The input discrete function, represented as a ChunkSpan storing values on a spectral mesh.
395 * @param kwargs The kwArgs_fft configuring the iFFT.
396 */
397template <
398 typename Tin,
399 typename Tout,
400 typename... DDimX,
401 typename... DDimFx,
402 typename ExecSpace,
403 typename MemorySpace,
404 typename LayoutIn,
405 typename LayoutOut>
406void ifft(
407 ExecSpace const& exec_space,
408 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimX...>, LayoutOut, MemorySpace> out,
409 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimFx...>, LayoutIn, MemorySpace> in,
411{
412 static_assert(
413 std::is_same_v<LayoutIn, Kokkos::layout_right>
414 && std::is_same_v<LayoutOut, Kokkos::layout_right>,
415 "Layouts must be right-handed");
416 static_assert(
417 (is_uniform_point_sampling_v<DDimX> && ...),
418 "DDimX dimensions should derive from UniformPointSampling");
419 static_assert(
420 (is_periodic_sampling_v<DDimFx> && ...),
421 "DDimFx dimensions should derive from PeriodicPointSampling");
422
423 ddc::detail::fft::
424 impl(exec_space, in, out, {ddc::FFT_Direction::BACKWARD, kwargs.normalization});
425}
426
427} // namespace ddc
friend class ChunkSpan
friend class DiscreteDomain
KOKKOS_FUNCTION constexpr bool operator!=(DiscreteVector< OTags... > const &rhs) const noexcept
The top-level namespace of DDC.
ddc::FFT_Normalization normalization
Enum member to identify the type of normalization performed.
Definition fft.hpp:320
void ifft(ExecSpace const &exec_space, ddc::ChunkSpan< Tout, ddc::DiscreteDomain< DDimX... >, LayoutOut, MemorySpace > out, ddc::ChunkSpan< Tin, ddc::DiscreteDomain< DDimFx... >, LayoutIn, MemorySpace > in, ddc::kwArgs_fft kwargs={ddc::FFT_Normalization::OFF})
Perform an inverse Fast Fourier Transform.
Definition fft.hpp:406
void fft(ExecSpace const &exec_space, ddc::ChunkSpan< Tout, ddc::DiscreteDomain< DDimFx... >, LayoutOut, MemorySpace > out, ddc::ChunkSpan< Tin, ddc::DiscreteDomain< DDimX... >, LayoutIn, MemorySpace > in, ddc::kwArgs_fft kwargs={ddc::FFT_Normalization::OFF})
Perform a direct Fast Fourier Transform.
Definition fft.hpp:353
FFT_Normalization
A named argument to choose the type of normalization of the FFT.
Definition fft.hpp:43
@ BACKWARD
No normalization for forward FFT, multiply by 1/N for backward FFT.
@ OFF
No normalization. Un-normalized FFT is sum_j f(x_j)*e^-ikx_j.
@ ORTHO
Multiply by 1/sqrt(N)
@ FULL
Multiply by dx/sqrt(2*pi) for forward FFT and dk/sqrt(2*pi) for backward FFT.
@ FORWARD
Multiply by 1/N for forward FFT, no normalization for backward FFT.
ddc::DiscreteDomain< DDimFx... > fourier_mesh(ddc::DiscreteDomain< DDimX... > x_mesh, bool C2C)
Get the Fourier mesh.
Definition fft.hpp:295
DDimFx::template Impl< DDimFx, Kokkos::HostSpace > init_fourier_space(ddc::DiscreteDomain< DDimX > x_mesh)
Initialize a Fourier discrete dimension.
Definition fft.hpp:254
FFT_Direction
A named argument to choose the direction of the FFT.
Definition fft.hpp:33
@ BACKWARD
Backward, corresponds to inverse FFT up to normalization.
@ FORWARD
Forward, corresponds to direct FFT up to normalization.
A templated tag representing a continuous dimension in the Fourier space associated to the original c...
Definition fft.hpp:25
A structure embedding the configuration of the exposed FFT function with the type of normalization.
Definition fft.hpp:318