DDC 0.10.0
Loading...
Searching...
No Matches
fft.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <cassert>
8#include <stdexcept>
9#include <type_traits>
10#include <utility>
11
12#include <ddc/ddc.hpp>
13
14#include <KokkosFFT.hpp>
15#include <Kokkos_Core.hpp>
16
17namespace ddc {
18
/**
 * @brief A templated tag representing a continuous dimension in the Fourier space associated to the original continuous dimension.
 *
 * @tparam Dim The tag representing the original (continuous) dimension.
 */
template <typename Dim>
struct Fourier;
26
/**
 * @brief A named argument to choose the direction of the FFT.
 *
 * ddc::fft uses FORWARD and ddc::ifft uses BACKWARD when forwarding to the internal implementation.
 *
 * @see fft, ifft, ddc::detail::fft::KwArgsImpl
 */
enum class FFT_Direction {
    FORWARD, ///< Forward, corresponds to direct FFT up to normalization
    BACKWARD ///< Backward, corresponds to inverse FFT up to normalization
};
36
/**
 * @brief A named argument to choose the type of normalization of the FFT.
 *
 * @see kwArgs_fft, fft, ifft
 */
enum class FFT_Normalization {
    OFF, ///< No normalization. Un-normalized FFT is sum_j f(x_j)*e^-ikx_j
    FORWARD, ///< Multiply by 1/N for forward FFT, no normalization for backward FFT
    BACKWARD, ///< No normalization for forward FFT, multiply by 1/N for backward FFT
    ORTHO, ///< Multiply by 1/sqrt(N)
    FULL /**<
          * Multiply by dx/sqrt(2*pi) for forward FFT and dk/sqrt(2*pi) for backward
          * FFT. It is aligned with the usual definition of the (continuous) Fourier transform
          * 1/sqrt(2*pi)*int f(x)*e^-ikx*dx, and thus may be relevant for spectral analysis applications.
          */
};
53
54} // namespace ddc
55
56namespace ddc::detail::fft {
57
58template <typename T>
59struct RealType
60{
61 using type = T;
62};
63
64template <typename T>
65struct RealType<Kokkos::complex<T>>
66{
67 using type = T;
68};
69
70template <typename T>
71using real_type_t = typename RealType<T>::type;
72
73// is_complex : trait to determine if type is Kokkos::complex<something>
74template <typename T>
75struct is_complex : std::false_type
76{
77};
78
79template <typename T>
80struct is_complex<Kokkos::complex<T>> : std::true_type
81{
82};
83
84template <typename T>
85constexpr bool is_complex_v = is_complex<T>::value;
86
87/*
88 * @brief A structure embedding the configuration of the impl FFT function: direction and type of normalization.
89 *
90 * @see FFT_impl
91 */
92struct KwArgsImpl
93{
95 direction; // Only effective for C2C transform and for normalization BACKWARD and FORWARD
96 ddc::FFT_Normalization normalization;
97};
98
99template <typename... DDimX>
100KokkosFFT::axis_type<sizeof...(DDimX)> axes()
101{
102 return KokkosFFT::axis_type<sizeof...(DDimX)> {
103 static_cast<int>(ddc::type_seq_rank_v<DDimX, ddc::detail::TypeSeq<DDimX...>>)...};
104}
105
106inline KokkosFFT::Normalization ddc_fft_normalization_to_kokkos_fft(
107 FFT_Normalization const ddc_fft_normalization)
108{
109 if (ddc_fft_normalization == ddc::FFT_Normalization::OFF
110 || ddc_fft_normalization == ddc::FFT_Normalization::FULL) {
111 return KokkosFFT::Normalization::none;
112 }
113
114 if (ddc_fft_normalization == ddc::FFT_Normalization::FORWARD) {
115 return KokkosFFT::Normalization::forward;
116 }
117
118 if (ddc_fft_normalization == ddc::FFT_Normalization::BACKWARD) {
119 return KokkosFFT::Normalization::backward;
120 }
121
122 if (ddc_fft_normalization == ddc::FFT_Normalization::ORTHO) {
123 return KokkosFFT::Normalization::ortho;
124 }
125
126 throw std::runtime_error("ddc::FFT_Normalization not handled");
127}
128
129template <class T>
130class ScaleFn
131{
132 T m_coef;
133
134public:
135 explicit ScaleFn(T coef) noexcept : m_coef(std::move(coef)) {}
136
137 template <class U>
138 [[nodiscard]] KOKKOS_FUNCTION U operator()(U const& value) const noexcept
139 {
140 return m_coef * value;
141 }
142};
143
144template <class DDim>
145Real forward_full_norm_coef(DiscreteDomain<DDim> const& ddom) noexcept
146{
147 return rlength(ddom) / Kokkos::sqrt(2 * Kokkos::numbers::pi_v<Real>)
148 / (ddom.extents() - 1).value();
149}
150
151template <class DDim>
152Real backward_full_norm_coef(DiscreteDomain<DDim> const& ddom) noexcept
153{
154 return 1 / (forward_full_norm_coef(ddom) * ddom.extents().value());
155}
156
157/// @brief Core internal function to perform the FFT.
158template <
159 typename Tin,
160 typename Tout,
161 typename ExecSpace,
162 typename MemorySpace,
163 typename LayoutIn,
164 typename LayoutOut,
165 typename... DDimIn,
166 typename... DDimOut>
167void impl(
168 ExecSpace const& exec_space,
169 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimIn...>, LayoutIn, MemorySpace> const& in,
170 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimOut...>, LayoutOut, MemorySpace> const& out,
171 KwArgsImpl const& kwargs)
172{
173 static_assert(
174 std::is_same_v<real_type_t<Tin>, float> || std::is_same_v<real_type_t<Tin>, double>,
175 "Base type of Tin (and Tout) must be float or double.");
176 static_assert(
177 std::is_same_v<real_type_t<Tin>, real_type_t<Tout>>,
178 "Types Tin and Tout must be based on same type (float or double)");
179 static_assert(
180 Kokkos::SpaceAccessibility<ExecSpace, MemorySpace>::accessible,
181 "MemorySpace has to be accessible for ExecutionSpace.");
182
183 Kokkos::View<
184 ddc::detail::mdspan_to_kokkos_element_t<Tin, sizeof...(DDimIn)>,
185 ddc::detail::mdspan_to_kokkos_layout_t<LayoutIn>,
186 MemorySpace> const in_view
187 = in.allocation_kokkos_view();
188 Kokkos::View<
189 ddc::detail::mdspan_to_kokkos_element_t<Tout, sizeof...(DDimIn)>,
190 ddc::detail::mdspan_to_kokkos_layout_t<LayoutOut>,
191 MemorySpace> const out_view
192 = out.allocation_kokkos_view();
193 KokkosFFT::Normalization const kokkos_fft_normalization
194 = ddc_fft_normalization_to_kokkos_fft(kwargs.normalization);
195
196 // C2C
197 if constexpr (std::is_same_v<Tin, Tout>) {
198 if (kwargs.direction == ddc::FFT_Direction::FORWARD) {
199 KokkosFFT::
200 fftn(exec_space,
201 in_view,
202 out_view,
203 axes<DDimIn...>(),
204 kokkos_fft_normalization);
205 } else {
206 KokkosFFT::
207 ifftn(exec_space,
208 in_view,
209 out_view,
210 axes<DDimIn...>(),
211 kokkos_fft_normalization);
212 }
213 // R2C & C2R
214 } else {
215 if constexpr (is_complex_v<Tout>) {
216 assert(kwargs.direction == ddc::FFT_Direction::FORWARD);
217 KokkosFFT::
218 rfftn(exec_space,
219 in_view,
220 out_view,
221 axes<DDimIn...>(),
222 kokkos_fft_normalization);
223 } else {
224 assert(kwargs.direction == ddc::FFT_Direction::BACKWARD);
225 KokkosFFT::
226 irfftn(exec_space,
227 in_view,
228 out_view,
229 axes<DDimIn...>(),
230 kokkos_fft_normalization);
231 }
232 }
233
234 // The FULL normalization is mesh-dependant and thus handled by DDC
235 if (kwargs.normalization == ddc::FFT_Normalization::FULL) {
236 Real norm_coef;
237 if (kwargs.direction == ddc::FFT_Direction::FORWARD) {
238 DiscreteDomain<DDimIn...> const ddom_in = in.domain();
239 norm_coef = (forward_full_norm_coef(DiscreteDomain<DDimIn>(ddom_in)) * ...);
240 } else {
241 DiscreteDomain<DDimOut...> const ddom_out = out.domain();
242 norm_coef = (backward_full_norm_coef(DiscreteDomain<DDimOut>(ddom_out)) * ...);
243 }
244
245 ddc::parallel_transform(exec_space, out, ScaleFn<real_type_t<Tout>>(norm_coef));
246 }
247}
248
249} // namespace ddc::detail::fft
250
251namespace ddc {
252
253/**
254 * @brief Initialize a Fourier discrete dimension.
255 *
256 * Initialize the (1D) discrete space representing the Fourier discrete dimension associated
257 * to the (1D) mesh passed as argument. It is a N-periodic PeriodicSampling with a periodic window of width 2*pi/dx.
258 *
259 * This value comes from the Nyquist-Shannon theorem: the period of the spectral domain is N*dk = 2*pi/dx.
260 * Adding to this the relations dx = (xmax-xmin)/(N-1), and dk = (kmax-kmin)/(N-1), we get kmax-kmin = 2*pi*(N-1)^2/N/(xmax-xmin),
261 * which is used in the implementation (xmax, xmin, kmin and kmax are the centers of lower and upper cells inside a single period of the meshes).
262 *
263 * @tparam DDimFx A PeriodicSampling representing the Fourier discrete dimension.
264 * @tparam DDimX The type of the original discrete dimension.
265 *
266 * @param x_mesh The DiscreteDomain representing the (1D) original mesh.
267 *
268 * @return The initialized Impl representing the discrete Fourier space.
269 *
270 * @see PeriodicSampling
271 */
272template <typename DDimFx, typename DDimX>
273typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
274 ddc::DiscreteDomain<DDimX> x_mesh)
275{
276 static_assert(
277 is_uniform_point_sampling_v<DDimX>,
278 "DDimX dimension must derive from UniformPointSampling");
279 static_assert(
280 is_periodic_sampling_v<DDimFx>,
281 "DDimFx dimension must derive from PeriodicSampling");
282 using CDimFx = typename DDimFx::continuous_dimension_type;
283 using CDimX = typename DDimX::continuous_dimension_type;
284 static_assert(
285 std::is_same_v<CDimFx, ddc::Fourier<CDimX>>,
286 "DDimX and DDimFx dimensions must be defined over the same continuous dimension");
287
288 DiscreteVectorElement const nx = get<DDimX>(x_mesh.extents());
289 double const lx = ddc::rlength(x_mesh);
290 auto [impl, ddom] = DDimFx::template init<DDimFx>(
291 ddc::Coordinate<CDimFx>(0),
292 ddc::Coordinate<CDimFx>(2 * (nx - 1) * (nx - 1) / (nx * lx) * Kokkos::numbers::pi),
293 ddc::DiscreteVector<DDimFx>(nx),
294 ddc::DiscreteVector<DDimFx>(nx));
295 return std::move(impl);
296}
297
298/**
299 * @brief Get the Fourier mesh.
300 *
301 * Compute the Fourier (or spectral) mesh on which the Discrete Fourier Transform of a
302 * discrete function is defined.
303 *
304 * @param x_mesh The DiscreteDomain representing the original mesh.
305 * @param C2C A flag indicating if a complex-to-complex DFT is going to be performed. Indeed,
306 * in this case the two meshes have same number of points, whereas for real-to-complex
307 * or complex-to-real DFT, each complex value of the Fourier-transformed function contains twice more
308 * information, and thus only half (actually Nx*Ny*(Nz/2+1) for 3D R2C FFT to take in account mode 0)
309 * values are needed (cf. DFT conjugate symmetry property for more information about this).
310 *
311 * @return The domain representing the Fourier mesh.
312 */
313template <typename... DDimFx, typename... DDimX>
314ddc::DiscreteDomain<DDimFx...> fourier_mesh(ddc::DiscreteDomain<DDimX...> x_mesh, bool C2C)
315{
316 static_assert(
317 (is_uniform_point_sampling_v<DDimX> && ...),
318 "DDimX dimensions should derive from UniformPointSampling");
319 static_assert(
320 (is_periodic_sampling_v<DDimFx> && ...),
321 "DDimFx dimensions should derive from PeriodicPointSampling");
322 ddc::DiscreteVector<DDimX...> extents = x_mesh.extents();
323 if (!C2C) {
324 detail::array(extents).back() = detail::array(extents).back() / 2 + 1;
325 }
326 return ddc::DiscreteDomain<DDimFx...>(ddc::DiscreteDomain<DDimFx>(
327 ddc::DiscreteElement<DDimFx>(0),
328 ddc::DiscreteVector<DDimFx>(get<DDimX>(extents)))...);
329}
330
331/**
332 * @brief A structure embedding the configuration of the exposed FFT function with the type of normalization.
333 *
334 * @see fft, ifft
335 */
336struct kwArgs_fft
337{
339 normalization; ///< Enum member to identify the type of normalization performed
340};
341
342/**
343 * @brief Perform a direct Fast Fourier Transform.
344 *
345 * Compute the discrete Fourier transform of a function using the specialized implementation for the Kokkos::ExecutionSpace
346 * of the FFT algorithm.
347 *
348 * @tparam Tin The type of the input elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
349 * @tparam Tout The type of the output elements (Kokkos::complex<float> or Kokkos::complex<double>).
350 * @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
351 * @tparam DDimX... The parameter pack of the original discrete dimensions.
352 * @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the FFT is performed. It determines which specialized
353 * backend is used (ie. fftw, cuFFT...).
354 * @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
355 * @tparam LayoutIn The layout of the Chunkspan representing the input discrete function.
356 * @tparam LayoutOut The layout of the Chunkspan representing the output discrete function.
357 *
358 * @param exec_space The Kokkos::ExecutionSpace on which the FFT is performed.
359 * @param out The output discrete function, represented as a ChunkSpan storing values on a spectral mesh.
360 * @param in The input discrete function, represented as a ChunkSpan storing values on a mesh.
361 * @param kwargs The kwArgs_fft configuring the FFT.
362 */
363template <
364 typename Tin,
365 typename Tout,
366 typename... DDimFx,
367 typename... DDimX,
368 typename ExecSpace,
369 typename MemorySpace,
370 typename LayoutIn,
371 typename LayoutOut>
372void fft(
373 ExecSpace const& exec_space,
374 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimFx...>, LayoutOut, MemorySpace> out,
375 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimX...>, LayoutIn, MemorySpace> in,
377{
378 static_assert(
379 std::is_same_v<LayoutIn, Kokkos::layout_right>
380 && std::is_same_v<LayoutOut, Kokkos::layout_right>,
381 "Layouts must be right-handed");
382 static_assert(
383 (is_uniform_point_sampling_v<DDimX> && ...),
384 "DDimX dimensions should derive from UniformPointSampling");
385 static_assert(
386 (is_periodic_sampling_v<DDimFx> && ...),
387 "DDimFx dimensions should derive from PeriodicPointSampling");
388
389 ddc::detail::fft::
390 impl(exec_space, in, out, {ddc::FFT_Direction::FORWARD, kwargs.normalization});
391}
392
393/**
394 * @brief Perform an inverse Fast Fourier Transform.
395 *
396 * Compute the inverse discrete Fourier transform of a spectral function using the specialized implementation for the Kokkos::ExecutionSpace
397 * of the iFFT algorithm.
398 *
399 * @warning C2R iFFT does NOT preserve input.
400 *
401 * @tparam Tin The type of the input elements (Kokkos::complex<float> or Kokkos::complex<double>).
402 * @tparam Tout The type of the output elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
403 * @tparam DDimX... The parameter pack of the original discrete dimensions.
404 * @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
405 * @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the iFFT is performed. It determines which specialized
406 * backend is used (ie. fftw, cuFFT...).
407 * @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
408 * @tparam LayoutIn The layout of the Chunkspan representing the input discrete function.
409 * @tparam LayoutOut The layout of the Chunkspan representing the output discrete function.
410 *
411 * @param exec_space The Kokkos::ExecutionSpace on which the iFFT is performed.
412 * @param out The output discrete function, represented as a ChunkSpan storing values on a mesh.
413 * @param in The input discrete function, represented as a ChunkSpan storing values on a spectral mesh.
414 * @param kwargs The kwArgs_fft configuring the iFFT.
415 */
416template <
417 typename Tin,
418 typename Tout,
419 typename... DDimX,
420 typename... DDimFx,
421 typename ExecSpace,
422 typename MemorySpace,
423 typename LayoutIn,
424 typename LayoutOut>
425void ifft(
426 ExecSpace const& exec_space,
427 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimX...>, LayoutOut, MemorySpace> out,
428 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimFx...>, LayoutIn, MemorySpace> in,
430{
431 static_assert(
432 std::is_same_v<LayoutIn, Kokkos::layout_right>
433 && std::is_same_v<LayoutOut, Kokkos::layout_right>,
434 "Layouts must be right-handed");
435 static_assert(
436 (is_uniform_point_sampling_v<DDimX> && ...),
437 "DDimX dimensions should derive from UniformPointSampling");
438 static_assert(
439 (is_periodic_sampling_v<DDimFx> && ...),
440 "DDimFx dimensions should derive from PeriodicPointSampling");
441
442 ddc::detail::fft::
443 impl(exec_space, in, out, {ddc::FFT_Direction::BACKWARD, kwargs.normalization});
444}
445
446} // namespace ddc
friend class ChunkSpan
friend class DiscreteDomain
KOKKOS_FUNCTION constexpr bool operator!=(DiscreteVector< OTags... > const &rhs) const noexcept
The top-level namespace of DDC.
ddc::FFT_Normalization normalization
Enum member to identify the type of normalization performed.
Definition fft.hpp:339
void ifft(ExecSpace const &exec_space, ddc::ChunkSpan< Tout, ddc::DiscreteDomain< DDimX... >, LayoutOut, MemorySpace > out, ddc::ChunkSpan< Tin, ddc::DiscreteDomain< DDimFx... >, LayoutIn, MemorySpace > in, ddc::kwArgs_fft kwargs={ddc::FFT_Normalization::OFF})
Perform an inverse Fast Fourier Transform.
Definition fft.hpp:425
void fft(ExecSpace const &exec_space, ddc::ChunkSpan< Tout, ddc::DiscreteDomain< DDimFx... >, LayoutOut, MemorySpace > out, ddc::ChunkSpan< Tin, ddc::DiscreteDomain< DDimX... >, LayoutIn, MemorySpace > in, ddc::kwArgs_fft kwargs={ddc::FFT_Normalization::OFF})
Perform a direct Fast Fourier Transform.
Definition fft.hpp:372
FFT_Normalization
A named argument to choose the type of normalization of the FFT.
Definition fft.hpp:42
@ BACKWARD
No normalization for forward FFT, multiply by 1/N for backward FFT.
@ OFF
No normalization. Un-normalized FFT is sum_j f(x_j)*e^-ikx_j.
@ ORTHO
Multiply by 1/sqrt(N)
@ FULL
Multiply by dx/sqrt(2*pi) for forward FFT and dk/sqrt(2*pi) for backward FFT.
@ FORWARD
Multiply by 1/N for forward FFT, no normalization for backward FFT.
ddc::DiscreteDomain< DDimFx... > fourier_mesh(ddc::DiscreteDomain< DDimX... > x_mesh, bool C2C)
Get the Fourier mesh.
Definition fft.hpp:314
DDimFx::template Impl< DDimFx, Kokkos::HostSpace > init_fourier_space(ddc::DiscreteDomain< DDimX > x_mesh)
Initialize a Fourier discrete dimension.
Definition fft.hpp:273
FFT_Direction
A named argument to choose the direction of the FFT.
Definition fft.hpp:32
@ BACKWARD
Backward, corresponds to inverse FFT up to normalization.
@ FORWARD
Forward, corresponds to direct FFT up to normalization.
A structure embedding the configuration of the exposed FFT function with the type of normalization.
Definition fft.hpp:337