DDC 0.1.0
Loading...
Searching...
No Matches
fft.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
#include <cassert>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include <ddc/ddc.hpp>

#include <KokkosFFT.hpp>
#include <Kokkos_Core.hpp>
15
16namespace ddc {
17
/**
 * @brief A templated tag representing a continuous dimension in the Fourier space associated to the original continuous dimension.
 *
 * @tparam Dim The tag representing the original dimension.
 */
template <typename Dim>
struct Fourier;
25
/**
 * @brief Named argument selecting which way an FFT goes.
 *
 * Only the direction is encoded here; the scaling of the result is chosen
 * independently through FFT_Normalization.
 *
 * @see kwArgs_impl, kwArgs_fft
 */
enum class FFT_Direction : int {
    FORWARD = 0, ///< Forward, corresponds to direct FFT up to normalization
    BACKWARD = 1 ///< Backward, corresponds to inverse FFT up to normalization
};
35
/**
 * @brief Named argument selecting how the result of an FFT is scaled.
 *
 * For the mesh-dependent FULL mode, dx and dk are the steps of the physical and of the
 * Fourier meshes; in several dimensions the per-dimension coefficients are multiplied together.
 *
 * @see kwArgs_impl, kwArgs_fft
 */
enum class FFT_Normalization : int {
    OFF = 0, ///< No normalization. Un-normalized FFT is sum_j f(x_j)*e^-ikx_j
    FORWARD = 1, ///< Multiply by 1/N for forward FFT, no normalization for backward FFT
    BACKWARD = 2, ///< No normalization for forward FFT, multiply by 1/N for backward FFT
    ORTHO = 3, ///< Multiply by 1/sqrt(N)
    FULL = 4 /**<
              * Multiply by dx/sqrt(2*pi) for forward FFT and dk/sqrt(2*pi) for backward
              * FFT. It is aligned with the usual definition of the (continuous) Fourier transform
              * 1/sqrt(2*pi)*int f(x)*e^-ikx*dx, and thus may be relevant for spectral analysis applications.
              */
};
52
53} // namespace ddc
54
55namespace ddc::detail::fft {
56
57template <typename T>
58struct real_type
59{
60 using type = T;
61};
62
63template <typename T>
64struct real_type<Kokkos::complex<T>>
65{
66 using type = T;
67};
68
69template <typename T>
70using real_type_t = typename real_type<T>::type;
71
72// is_complex : trait to determine if type is Kokkos::complex<something>
73template <typename T>
74struct is_complex : std::false_type
75{
76};
77
78template <typename T>
79struct is_complex<Kokkos::complex<T>> : std::true_type
80{
81};
82
83template <typename T>
84constexpr bool is_complex_v = is_complex<T>::value;
85
86// LastSelector: returns a if Dim==Last, else b
87template <typename T, typename Dim, typename Last>
88KOKKOS_FUNCTION constexpr T LastSelector(const T a, const T b)
89{
90 return std::is_same_v<Dim, Last> ? a : b;
91}
92
93template <typename T, typename Dim, typename First, typename Second, typename... Tail>
94KOKKOS_FUNCTION constexpr T LastSelector(const T a, const T b)
95{
96 return LastSelector<T, Dim, Second, Tail...>(a, b);
97}
98
99/*
100 * @brief A structure embedding the configuration of the impl FFT function: direction and type of normalization.
101 *
102 * @see FFT_impl
103 */
104struct kwArgs_impl
105{
107 direction; // Only effective for C2C transform and for normalization BACKWARD and FORWARD
108 ddc::FFT_Normalization normalization;
109};
110
111/**
112 * @brief Get the mesh size along a given dimension.
113 *
114 * @tparam DDim The dimension along which the mesh size is returned.
115 * @param x_mesh The mesh.
116 *
117 * @return The mesh size along the required dimension.
118 */
119template <typename DDim, typename... DDimX>
120int N(ddc::DiscreteDomain<DDimX...> x_mesh)
121{
122 static_assert(
123 (is_uniform_point_sampling_v<DDimX> && ...),
124 "DDimX dimensions should derive from UniformPointSampling");
125 return static_cast<int>(x_mesh.template extent<DDim>());
126}
127
128template <typename... DDimX>
129KokkosFFT::axis_type<sizeof...(DDimX)> axes()
130{
131 return KokkosFFT::axis_type<sizeof...(DDimX)> {
132 static_cast<int>(ddc::type_seq_rank_v<DDimX, ddc::detail::TypeSeq<DDimX...>>)...};
133}
134
135inline KokkosFFT::Normalization ddc_fft_normalization_to_kokkos_fft(
136 FFT_Normalization const ddc_fft_normalization)
137{
138 if (ddc_fft_normalization == ddc::FFT_Normalization::OFF
139 || ddc_fft_normalization == ddc::FFT_Normalization::FULL) {
140 return KokkosFFT::Normalization::none;
141 }
142
143 if (ddc_fft_normalization == ddc::FFT_Normalization::FORWARD) {
144 return KokkosFFT::Normalization::forward;
145 }
146
147 if (ddc_fft_normalization == ddc::FFT_Normalization::BACKWARD) {
148 return KokkosFFT::Normalization::backward;
149 }
150
151 if (ddc_fft_normalization == ddc::FFT_Normalization::ORTHO) {
152 return KokkosFFT::Normalization::ortho;
153 }
154
155 throw std::runtime_error("ddc::FFT_Normalization not handled");
156}
157
158template <
159 typename ExecSpace,
160 typename ElementType,
161 typename DDom,
162 typename Layout,
163 typename MemorySpace,
164 typename T>
165void rescale(
166 ExecSpace const& exec_space,
167 ddc::ChunkSpan<ElementType, DDom, Layout, MemorySpace> const& chunk_span,
168 T const& value)
169{
170 ddc::parallel_for_each(
171 "ddc_fft_normalization",
172 exec_space,
173 chunk_span.domain(),
174 KOKKOS_LAMBDA(typename DDom::discrete_element_type const i) {
175 chunk_span(i) *= value;
176 });
177}
178
179template <class DDim>
180Real forward_full_norm_coef(DiscreteDomain<DDim> const& ddom) noexcept
181{
182 return rlength(ddom) / Kokkos::sqrt(2 * Kokkos::numbers::pi_v<Real>)
183 / (ddom.extents() - 1).value();
184}
185
186template <class DDim>
187Real backward_full_norm_coef(DiscreteDomain<DDim> const& ddom) noexcept
188{
189 return 1 / (forward_full_norm_coef(ddom) * ddom.extents().value());
190}
191
192/// @brief Core internal function to perform the FFT.
193template <
194 typename Tin,
195 typename Tout,
196 typename ExecSpace,
197 typename MemorySpace,
198 typename LayoutIn,
199 typename LayoutOut,
200 typename... DDimIn,
201 typename... DDimOut>
202void impl(
203 ExecSpace const& exec_space,
204 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimIn...>, LayoutIn, MemorySpace> const& in,
205 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimOut...>, LayoutOut, MemorySpace> const& out,
206 kwArgs_impl const& kwargs)
207{
208 static_assert(
209 std::is_same_v<real_type_t<Tin>, float> || std::is_same_v<real_type_t<Tin>, double>,
210 "Base type of Tin (and Tout) must be float or double.");
211 static_assert(
212 std::is_same_v<real_type_t<Tin>, real_type_t<Tout>>,
213 "Types Tin and Tout must be based on same type (float or double)");
214 static_assert(
215 Kokkos::SpaceAccessibility<ExecSpace, MemorySpace>::accessible,
216 "MemorySpace has to be accessible for ExecutionSpace.");
217
218 Kokkos::View<
219 ddc::detail::mdspan_to_kokkos_element_t<Tin, sizeof...(DDimIn)>,
220 ddc::detail::mdspan_to_kokkos_layout_t<LayoutIn>,
221 MemorySpace> const in_view
222 = in.allocation_kokkos_view();
223 Kokkos::View<
224 ddc::detail::mdspan_to_kokkos_element_t<Tout, sizeof...(DDimIn)>,
225 ddc::detail::mdspan_to_kokkos_layout_t<LayoutOut>,
226 MemorySpace> const out_view
227 = out.allocation_kokkos_view();
228 KokkosFFT::Normalization const kokkos_fft_normalization
229 = ddc_fft_normalization_to_kokkos_fft(kwargs.normalization);
230
231 // C2C
232 if constexpr (std::is_same_v<Tin, Tout>) {
233 if (kwargs.direction == ddc::FFT_Direction::FORWARD) {
234 KokkosFFT::
235 fftn(exec_space,
236 in_view,
237 out_view,
238 axes<DDimIn...>(),
239 kokkos_fft_normalization);
240 } else {
241 KokkosFFT::
242 ifftn(exec_space,
243 in_view,
244 out_view,
245 axes<DDimIn...>(),
246 kokkos_fft_normalization);
247 }
248 // R2C & C2R
249 } else {
250 if constexpr (is_complex_v<Tout>) {
251 assert(kwargs.direction == ddc::FFT_Direction::FORWARD);
252 KokkosFFT::
253 rfftn(exec_space,
254 in_view,
255 out_view,
256 axes<DDimIn...>(),
257 kokkos_fft_normalization);
258 } else {
259 assert(kwargs.direction == ddc::FFT_Direction::BACKWARD);
260 KokkosFFT::
261 irfftn(exec_space,
262 in_view,
263 out_view,
264 axes<DDimIn...>(),
265 kokkos_fft_normalization);
266 }
267 }
268
269 // The FULL normalization is mesh-dependant and thus handled by DDC
270 if (kwargs.normalization == ddc::FFT_Normalization::FULL) {
271 Real norm_coef;
272 if (kwargs.direction == ddc::FFT_Direction::FORWARD) {
273 DiscreteDomain<DDimIn...> const ddom_in = in.domain();
274 norm_coef = (forward_full_norm_coef(DiscreteDomain<DDimIn>(ddom_in)) * ...);
275 } else {
276 DiscreteDomain<DDimOut...> const ddom_out = out.domain();
277 norm_coef = (backward_full_norm_coef(DiscreteDomain<DDimOut>(ddom_out)) * ...);
278 }
279
280 rescale(exec_space, out, static_cast<real_type_t<Tout>>(norm_coef));
281 }
282}
283
284} // namespace ddc::detail::fft
285
286namespace ddc {
287
288/**
289 * @brief Initialize a Fourier discrete dimension.
290 *
291 * Initialize the (1D) discrete space representing the Fourier discrete dimension associated
292 * to the (1D) mesh passed as argument. It is a N-periodic PeriodicSampling with a periodic window of width 2*pi/dx.
293 *
294 * This value comes from the Nyquist-Shannon theorem: the period of the spectral domain is N*dk = 2*pi/dx.
295 * Adding to this the relations dx = (xmax-xmin)/(N-1), and dk = (kmax-kmin)/(N-1), we get kmax-kmin = 2*pi*(N-1)^2/N/(xmax-xmin),
296 * which is used in the implementation (xmax, xmin, kmin and kmax are the centers of lower and upper cells inside a single period of the meshes).
297 *
298 * @tparam DDimFx A PeriodicSampling representing the Fourier discrete dimension.
299 * @tparam DDimX The type of the original discrete dimension.
300 *
301 * @param x_mesh The DiscreteDomain representing the (1D) original mesh.
302 *
303 * @return The initialized Impl representing the discrete Fourier space.
304 *
305 * @see PeriodicSampling
306 */
307template <typename DDimFx, typename DDimX>
308typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
309 ddc::DiscreteDomain<DDimX> x_mesh)
310{
311 static_assert(
312 is_uniform_point_sampling_v<DDimX>,
313 "DDimX dimensions should derive from UniformPointSampling");
314 static_assert(
315 is_periodic_sampling_v<DDimFx>,
316 "DDimFx dimensions should derive from PeriodicSampling");
317 auto [impl, ddom] = DDimFx::template init<DDimFx>(
318 ddc::Coordinate<typename DDimFx::continuous_dimension_type>(0),
319 ddc::Coordinate<typename DDimFx::continuous_dimension_type>(
320 2 * (ddc::detail::fft::N<DDimX>(x_mesh) - 1)
321 * (ddc::detail::fft::N<DDimX>(x_mesh) - 1)
322 / static_cast<double>(
323 ddc::detail::fft::N<DDimX>(x_mesh)
324 * (ddc::coordinate(x_mesh.back()) - ddc::coordinate(x_mesh.front())))
325 * Kokkos::numbers::pi),
326 ddc::DiscreteVector<DDimFx>(ddc::detail::fft::N<DDimX>(x_mesh)),
327 ddc::DiscreteVector<DDimFx>(ddc::detail::fft::N<DDimX>(x_mesh)));
328 return std::move(impl);
329}
330
331/**
332 * @brief Get the Fourier mesh.
333 *
334 * Compute the Fourier (or spectral) mesh on which the Discrete Fourier Transform of a
335 * discrete function is defined.
336 *
337 * @param x_mesh The DiscreteDomain representing the original mesh.
338 * @param C2C A flag indicating if a complex-to-complex DFT is going to be performed. Indeed,
339 * in this case the two meshes have same number of points, whereas for real-to-complex
340 * or complex-to-real DFT, each complex value of the Fourier-transformed function contains twice more
341 * information, and thus only half (actually Nx*Ny*(Nz/2+1) for 3D R2C FFT to take in account mode 0)
342 * values are needed (cf. DFT conjugate symmetry property for more information about this).
343 *
344 * @return The domain representing the Fourier mesh.
345 */
346template <typename... DDimFx, typename... DDimX>
347ddc::DiscreteDomain<DDimFx...> fourier_mesh(ddc::DiscreteDomain<DDimX...> x_mesh, bool C2C)
348{
349 static_assert(
350 (is_uniform_point_sampling_v<DDimX> && ...),
351 "DDimX dimensions should derive from UniformPointSampling");
352 static_assert(
353 (is_periodic_sampling_v<DDimFx> && ...),
354 "DDimFx dimensions should derive from PeriodicPointSampling");
355 return ddc::DiscreteDomain<DDimFx...>(ddc::DiscreteDomain<DDimFx>(
356 ddc::DiscreteElement<DDimFx>(0),
357 ddc::DiscreteVector<DDimFx>(
358 (C2C ? ddc::detail::fft::N<DDimX>(x_mesh)
359 : ddc::detail::fft::LastSelector<int, DDimX, DDimX...>(
360 ddc::detail::fft::N<DDimX>(x_mesh) / 2 + 1,
361 ddc::detail::fft::N<DDimX>(x_mesh)))))...);
362}
363
364/**
365 * @brief Get the Fourier mesh.
366 *
367 * Compute the Fourier (or spectral) mesh on which the Discrete Fourier Transform of a
368 * discrete function is defined.
369 *
370 * @deprecated Use @ref fourier_mesh instead.
371 *
372 * @param x_mesh The DiscreteDomain representing the original mesh.
373 * @param C2C A flag indicating if a complex-to-complex DFT is going to be performed. Indeed,
374 * in this case the two meshes have same number of points, whereas for real-to-complex
375 * or complex-to-real DFT, each complex value of the Fourier-transformed function contains twice more
376 * information, and thus only half (actually Nx*Ny*(Nz/2+1) for 3D R2C FFT to take in account mode 0)
377 * values are needed (cf. DFT conjugate symmetry property for more information about this).
378 *
379 * @return The domain representing the Fourier mesh.
380 */
381template <typename... DDimFx, typename... DDimX>
382[[deprecated("Use `fourier_mesh` instead")]] ddc::DiscreteDomain<DDimFx...> FourierMesh(
383 ddc::DiscreteDomain<DDimX...> x_mesh,
384 bool C2C)
385{
386 return fourier_mesh<DDimFx...>(x_mesh, C2C);
387}
388
389/**
390 * @brief A structure embedding the configuration of the exposed FFT function with the type of normalization.
391 *
392 * @see fft, ifft
393 */
394struct kwArgs_fft
395{
397 normalization; ///< Enum member to identify the type of normalization performed
398};
399
400/**
401 * @brief Perform a direct Fast Fourier Transform.
402 *
403 * Compute the discrete Fourier transform of a function using the specialized implementation for the Kokkos::ExecutionSpace
404 * of the FFT algorithm.
405 *
406 * @tparam Tin The type of the input elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
407 * @tparam Tout The type of the output elements (Kokkos::complex<float> or Kokkos::complex<double>).
408 * @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
409 * @tparam DDimX... The parameter pack of the original discrete dimensions.
410 * @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the FFT is performed. It determines which specialized
411 * backend is used (ie. fftw, cuFFT...).
412 * @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
413 * @tparam LayoutIn The layout of the Chunkspan representing the input discrete function.
414 * @tparam LayoutOut The layout of the Chunkspan representing the output discrete function.
415 *
416 * @param exec_space The Kokkos::ExecutionSpace on which the FFT is performed.
417 * @param out The output discrete function, represented as a ChunkSpan storing values on a spectral mesh.
418 * @param in The input discrete function, represented as a ChunkSpan storing values on a mesh.
419 * @param kwargs The kwArgs_fft configuring the FFT.
420 */
421template <
422 typename Tin,
423 typename Tout,
424 typename... DDimFx,
425 typename... DDimX,
426 typename ExecSpace,
427 typename MemorySpace,
428 typename LayoutIn,
429 typename LayoutOut>
430void fft(
431 ExecSpace const& exec_space,
432 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimFx...>, LayoutOut, MemorySpace> out,
433 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimX...>, LayoutIn, MemorySpace> in,
435{
436 static_assert(
437 std::is_same_v<LayoutIn, Kokkos::layout_right>
438 && std::is_same_v<LayoutOut, Kokkos::layout_right>,
439 "Layouts must be right-handed");
440 static_assert(
441 (is_uniform_point_sampling_v<DDimX> && ...),
442 "DDimX dimensions should derive from UniformPointSampling");
443 static_assert(
444 (is_periodic_sampling_v<DDimFx> && ...),
445 "DDimFx dimensions should derive from PeriodicPointSampling");
446
447 ddc::detail::fft::
448 impl(exec_space, in, out, {ddc::FFT_Direction::FORWARD, kwargs.normalization});
449}
450
451/**
452 * @brief Perform an inverse Fast Fourier Transform.
453 *
454 * Compute the inverse discrete Fourier transform of a spectral function using the specialized implementation for the Kokkos::ExecutionSpace
455 * of the iFFT algorithm.
456 *
457 * @warning C2R iFFT does NOT preserve input.
458 *
459 * @tparam Tin The type of the input elements (Kokkos::complex<float> or Kokkos::complex<double>).
460 * @tparam Tout The type of the output elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
461 * @tparam DDimX... The parameter pack of the original discrete dimensions.
462 * @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
463 * @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the iFFT is performed. It determines which specialized
464 * backend is used (ie. fftw, cuFFT...).
465 * @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
466 * @tparam LayoutIn The layout of the Chunkspan representing the input discrete function.
467 * @tparam LayoutOut The layout of the Chunkspan representing the output discrete function.
468 *
469 * @param exec_space The Kokkos::ExecutionSpace on which the iFFT is performed.
470 * @param out The output discrete function, represented as a ChunkSpan storing values on a mesh.
471 * @param in The input discrete function, represented as a ChunkSpan storing values on a spectral mesh.
472 * @param kwargs The kwArgs_fft configuring the iFFT.
473 */
474template <
475 typename Tin,
476 typename Tout,
477 typename... DDimX,
478 typename... DDimFx,
479 typename ExecSpace,
480 typename MemorySpace,
481 typename LayoutIn,
482 typename LayoutOut>
483void ifft(
484 ExecSpace const& exec_space,
485 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimX...>, LayoutOut, MemorySpace> out,
486 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimFx...>, LayoutIn, MemorySpace> in,
488{
489 static_assert(
490 std::is_same_v<LayoutIn, Kokkos::layout_right>
491 && std::is_same_v<LayoutOut, Kokkos::layout_right>,
492 "Layouts must be right-handed");
493 static_assert(
494 (is_uniform_point_sampling_v<DDimX> && ...),
495 "DDimX dimensions should derive from UniformPointSampling");
496 static_assert(
497 (is_periodic_sampling_v<DDimFx> && ...),
498 "DDimFx dimensions should derive from PeriodicPointSampling");
499
500 ddc::detail::fft::
501 impl(exec_space, in, out, {ddc::FFT_Direction::BACKWARD, kwargs.normalization});
502}
503
504} // namespace ddc
friend class DiscreteDomain
KOKKOS_FUNCTION constexpr bool operator!=(DiscreteVector< OTags... > const &rhs) const noexcept
The top-level namespace of DDC.
ddc::FFT_Normalization normalization
Enum member to identify the type of normalization performed.
Definition fft.hpp:397
void ifft(ExecSpace const &exec_space, ddc::ChunkSpan< Tout, ddc::DiscreteDomain< DDimX... >, LayoutOut, MemorySpace > out, ddc::ChunkSpan< Tin, ddc::DiscreteDomain< DDimFx... >, LayoutIn, MemorySpace > in, ddc::kwArgs_fft kwargs={ddc::FFT_Normalization::OFF})
Perform an inverse Fast Fourier Transform.
Definition fft.hpp:483
void fft(ExecSpace const &exec_space, ddc::ChunkSpan< Tout, ddc::DiscreteDomain< DDimFx... >, LayoutOut, MemorySpace > out, ddc::ChunkSpan< Tin, ddc::DiscreteDomain< DDimX... >, LayoutIn, MemorySpace > in, ddc::kwArgs_fft kwargs={ddc::FFT_Normalization::OFF})
Perform a direct Fast Fourier Transform.
Definition fft.hpp:430
ddc::DiscreteDomain< DDimFx... > FourierMesh(ddc::DiscreteDomain< DDimX... > x_mesh, bool C2C)
Get the Fourier mesh.
Definition fft.hpp:382
FFT_Normalization
A named argument to choose the type of normalization of the FFT.
Definition fft.hpp:41
@ BACKWARD
No normalization for forward FFT, multiply by 1/N for backward FFT.
@ OFF
No normalization. Un-normalized FFT is sum_j f(x_j)*e^-ikx_j.
@ ORTHO
Multiply by 1/sqrt(N)
@ FULL
Multiply by dx/sqrt(2*pi) for forward FFT and dk/sqrt(2*pi) for backward FFT.
@ FORWARD
Multiply by 1/N for forward FFT, no normalization for backward FFT.
ddc::DiscreteDomain< DDimFx... > fourier_mesh(ddc::DiscreteDomain< DDimX... > x_mesh, bool C2C)
Get the Fourier mesh.
Definition fft.hpp:347
DDimFx::template Impl< DDimFx, Kokkos::HostSpace > init_fourier_space(ddc::DiscreteDomain< DDimX > x_mesh)
Initialize a Fourier discrete dimension.
Definition fft.hpp:308
FFT_Direction
A named argument to choose the direction of the FFT.
Definition fft.hpp:31
@ BACKWARD
Backward, corresponds to inverse FFT up to normalization.
@ FORWARD
Forward, corresponds to direct FFT up to normalization.
A structure embedding the configuration of the exposed FFT function with the type of normalization.
Definition fft.hpp:395