DDC 0.4.1
Loading...
Searching...
No Matches
fft.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
#include <cassert>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include <ddc/ddc.hpp>

#include <KokkosFFT.hpp>
#include <Kokkos_Core.hpp>
15
namespace ddc {

/**
 * @brief A templated tag representing a continuous dimension in the Fourier space associated to the original continuous dimension.
 *
 * @tparam Dim The tag representing the original (continuous) dimension.
 */
template <typename Dim>
struct Fourier;

/**
 * @brief A named argument to choose the direction of the FFT.
 *
 * The public entry points fft() and ifft() select the direction themselves; it is only
 * stored explicitly in the internal configuration structure.
 *
 * @see kwArgs_impl, fft, ifft
 */
enum class FFT_Direction {
    FORWARD, ///< Forward, corresponds to direct FFT up to normalization
    BACKWARD ///< Backward, corresponds to inverse FFT up to normalization
};

/**
 * @brief A named argument to choose the type of normalization of the FFT.
 *
 * @see kwArgs_impl, kwArgs_fft
 */
enum class FFT_Normalization {
    OFF, ///< No normalization. Un-normalized FFT is sum_j f(x_j)*e^-ikx_j
    FORWARD, ///< Multiply by 1/N for forward FFT, no normalization for backward FFT
    BACKWARD, ///< No normalization for forward FFT, multiply by 1/N for backward FFT
    ORTHO, ///< Multiply by 1/sqrt(N)
    FULL /**<
          * Multiply by dx/sqrt(2*pi) for forward FFT and dk/sqrt(2*pi) for backward
          * FFT. It is aligned with the usual definition of the (continuous) Fourier transform
          * 1/sqrt(2*pi)*int f(x)*e^-ikx*dx, and thus may be relevant for spectral analysis applications.
          */
};

} // namespace ddc
54
55namespace ddc::detail::fft {
56
57template <typename T>
58struct real_type
59{
60 using type = T;
61};
62
63template <typename T>
64struct real_type<Kokkos::complex<T>>
65{
66 using type = T;
67};
68
69template <typename T>
70using real_type_t = typename real_type<T>::type;
71
72// is_complex : trait to determine if type is Kokkos::complex<something>
73template <typename T>
74struct is_complex : std::false_type
75{
76};
77
78template <typename T>
79struct is_complex<Kokkos::complex<T>> : std::true_type
80{
81};
82
83template <typename T>
84constexpr bool is_complex_v = is_complex<T>::value;
85
86/*
87 * @brief A structure embedding the configuration of the impl FFT function: direction and type of normalization.
88 *
89 * @see FFT_impl
90 */
91struct kwArgs_impl
92{
94 direction; // Only effective for C2C transform and for normalization BACKWARD and FORWARD
95 ddc::FFT_Normalization normalization;
96};
97
98template <typename... DDimX>
99KokkosFFT::axis_type<sizeof...(DDimX)> axes()
100{
101 return KokkosFFT::axis_type<sizeof...(DDimX)> {
102 static_cast<int>(ddc::type_seq_rank_v<DDimX, ddc::detail::TypeSeq<DDimX...>>)...};
103}
104
105inline KokkosFFT::Normalization ddc_fft_normalization_to_kokkos_fft(
106 FFT_Normalization const ddc_fft_normalization)
107{
108 if (ddc_fft_normalization == ddc::FFT_Normalization::OFF
109 || ddc_fft_normalization == ddc::FFT_Normalization::FULL) {
110 return KokkosFFT::Normalization::none;
111 }
112
113 if (ddc_fft_normalization == ddc::FFT_Normalization::FORWARD) {
114 return KokkosFFT::Normalization::forward;
115 }
116
117 if (ddc_fft_normalization == ddc::FFT_Normalization::BACKWARD) {
118 return KokkosFFT::Normalization::backward;
119 }
120
121 if (ddc_fft_normalization == ddc::FFT_Normalization::ORTHO) {
122 return KokkosFFT::Normalization::ortho;
123 }
124
125 throw std::runtime_error("ddc::FFT_Normalization not handled");
126}
127
128template <
129 typename ExecSpace,
130 typename ElementType,
131 typename DDom,
132 typename Layout,
133 typename MemorySpace,
134 typename T>
135void rescale(
136 ExecSpace const& exec_space,
137 ddc::ChunkSpan<ElementType, DDom, Layout, MemorySpace> const& chunk_span,
138 T const& value)
139{
140 ddc::parallel_for_each(
141 "ddc_fft_normalization",
142 exec_space,
143 chunk_span.domain(),
144 KOKKOS_LAMBDA(typename DDom::discrete_element_type const i) {
145 chunk_span(i) *= value;
146 });
147}
148
149template <class DDim>
150Real forward_full_norm_coef(DiscreteDomain<DDim> const& ddom) noexcept
151{
152 return rlength(ddom) / Kokkos::sqrt(2 * Kokkos::numbers::pi_v<Real>)
153 / (ddom.extents() - 1).value();
154}
155
156template <class DDim>
157Real backward_full_norm_coef(DiscreteDomain<DDim> const& ddom) noexcept
158{
159 return 1 / (forward_full_norm_coef(ddom) * ddom.extents().value());
160}
161
162/// @brief Core internal function to perform the FFT.
163template <
164 typename Tin,
165 typename Tout,
166 typename ExecSpace,
167 typename MemorySpace,
168 typename LayoutIn,
169 typename LayoutOut,
170 typename... DDimIn,
171 typename... DDimOut>
172void impl(
173 ExecSpace const& exec_space,
174 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimIn...>, LayoutIn, MemorySpace> const& in,
175 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimOut...>, LayoutOut, MemorySpace> const& out,
176 kwArgs_impl const& kwargs)
177{
178 static_assert(
179 std::is_same_v<real_type_t<Tin>, float> || std::is_same_v<real_type_t<Tin>, double>,
180 "Base type of Tin (and Tout) must be float or double.");
181 static_assert(
182 std::is_same_v<real_type_t<Tin>, real_type_t<Tout>>,
183 "Types Tin and Tout must be based on same type (float or double)");
184 static_assert(
185 Kokkos::SpaceAccessibility<ExecSpace, MemorySpace>::accessible,
186 "MemorySpace has to be accessible for ExecutionSpace.");
187
188 Kokkos::View<
189 ddc::detail::mdspan_to_kokkos_element_t<Tin, sizeof...(DDimIn)>,
190 ddc::detail::mdspan_to_kokkos_layout_t<LayoutIn>,
191 MemorySpace> const in_view
192 = in.allocation_kokkos_view();
193 Kokkos::View<
194 ddc::detail::mdspan_to_kokkos_element_t<Tout, sizeof...(DDimIn)>,
195 ddc::detail::mdspan_to_kokkos_layout_t<LayoutOut>,
196 MemorySpace> const out_view
197 = out.allocation_kokkos_view();
198 KokkosFFT::Normalization const kokkos_fft_normalization
199 = ddc_fft_normalization_to_kokkos_fft(kwargs.normalization);
200
201 // C2C
202 if constexpr (std::is_same_v<Tin, Tout>) {
203 if (kwargs.direction == ddc::FFT_Direction::FORWARD) {
204 KokkosFFT::
205 fftn(exec_space,
206 in_view,
207 out_view,
208 axes<DDimIn...>(),
209 kokkos_fft_normalization);
210 } else {
211 KokkosFFT::
212 ifftn(exec_space,
213 in_view,
214 out_view,
215 axes<DDimIn...>(),
216 kokkos_fft_normalization);
217 }
218 // R2C & C2R
219 } else {
220 if constexpr (is_complex_v<Tout>) {
221 assert(kwargs.direction == ddc::FFT_Direction::FORWARD);
222 KokkosFFT::
223 rfftn(exec_space,
224 in_view,
225 out_view,
226 axes<DDimIn...>(),
227 kokkos_fft_normalization);
228 } else {
229 assert(kwargs.direction == ddc::FFT_Direction::BACKWARD);
230 KokkosFFT::
231 irfftn(exec_space,
232 in_view,
233 out_view,
234 axes<DDimIn...>(),
235 kokkos_fft_normalization);
236 }
237 }
238
239 // The FULL normalization is mesh-dependant and thus handled by DDC
240 if (kwargs.normalization == ddc::FFT_Normalization::FULL) {
241 Real norm_coef;
242 if (kwargs.direction == ddc::FFT_Direction::FORWARD) {
243 DiscreteDomain<DDimIn...> const ddom_in = in.domain();
244 norm_coef = (forward_full_norm_coef(DiscreteDomain<DDimIn>(ddom_in)) * ...);
245 } else {
246 DiscreteDomain<DDimOut...> const ddom_out = out.domain();
247 norm_coef = (backward_full_norm_coef(DiscreteDomain<DDimOut>(ddom_out)) * ...);
248 }
249
250 rescale(exec_space, out, static_cast<real_type_t<Tout>>(norm_coef));
251 }
252}
253
254} // namespace ddc::detail::fft
255
256namespace ddc {
257
258/**
259 * @brief Initialize a Fourier discrete dimension.
260 *
261 * Initialize the (1D) discrete space representing the Fourier discrete dimension associated
262 * to the (1D) mesh passed as argument. It is a N-periodic PeriodicSampling with a periodic window of width 2*pi/dx.
263 *
264 * This value comes from the Nyquist-Shannon theorem: the period of the spectral domain is N*dk = 2*pi/dx.
265 * Adding to this the relations dx = (xmax-xmin)/(N-1), and dk = (kmax-kmin)/(N-1), we get kmax-kmin = 2*pi*(N-1)^2/N/(xmax-xmin),
266 * which is used in the implementation (xmax, xmin, kmin and kmax are the centers of lower and upper cells inside a single period of the meshes).
267 *
268 * @tparam DDimFx A PeriodicSampling representing the Fourier discrete dimension.
269 * @tparam DDimX The type of the original discrete dimension.
270 *
271 * @param x_mesh The DiscreteDomain representing the (1D) original mesh.
272 *
273 * @return The initialized Impl representing the discrete Fourier space.
274 *
275 * @see PeriodicSampling
276 */
277template <typename DDimFx, typename DDimX>
278typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
279 ddc::DiscreteDomain<DDimX> x_mesh)
280{
281 static_assert(
282 is_uniform_point_sampling_v<DDimX>,
283 "DDimX dimension must derive from UniformPointSampling");
284 static_assert(
285 is_periodic_sampling_v<DDimFx>,
286 "DDimFx dimension must derive from PeriodicSampling");
287 using CDimFx = typename DDimFx::continuous_dimension_type;
288 using CDimX = typename DDimX::continuous_dimension_type;
289 static_assert(
290 std::is_same_v<CDimFx, ddc::Fourier<CDimX>>,
291 "DDimX and DDimFx dimensions must be defined over the same continuous dimension");
292
293 DiscreteVectorElement const nx = get<DDimX>(x_mesh.extents());
294 double const lx = ddc::rlength(x_mesh);
295 auto [impl, ddom] = DDimFx::template init<DDimFx>(
296 ddc::Coordinate<CDimFx>(0),
297 ddc::Coordinate<CDimFx>(2 * (nx - 1) * (nx - 1) / (nx * lx) * Kokkos::numbers::pi),
298 ddc::DiscreteVector<DDimFx>(nx),
299 ddc::DiscreteVector<DDimFx>(nx));
300 return std::move(impl);
301}
302
303/**
304 * @brief Get the Fourier mesh.
305 *
306 * Compute the Fourier (or spectral) mesh on which the Discrete Fourier Transform of a
307 * discrete function is defined.
308 *
309 * @param x_mesh The DiscreteDomain representing the original mesh.
310 * @param C2C A flag indicating if a complex-to-complex DFT is going to be performed. Indeed,
311 * in this case the two meshes have same number of points, whereas for real-to-complex
312 * or complex-to-real DFT, each complex value of the Fourier-transformed function contains twice more
313 * information, and thus only half (actually Nx*Ny*(Nz/2+1) for 3D R2C FFT to take in account mode 0)
314 * values are needed (cf. DFT conjugate symmetry property for more information about this).
315 *
316 * @return The domain representing the Fourier mesh.
317 */
318template <typename... DDimFx, typename... DDimX>
319ddc::DiscreteDomain<DDimFx...> fourier_mesh(ddc::DiscreteDomain<DDimX...> x_mesh, bool C2C)
320{
321 static_assert(
322 (is_uniform_point_sampling_v<DDimX> && ...),
323 "DDimX dimensions should derive from UniformPointSampling");
324 static_assert(
325 (is_periodic_sampling_v<DDimFx> && ...),
326 "DDimFx dimensions should derive from PeriodicPointSampling");
327 ddc::DiscreteVector<DDimX...> extents = x_mesh.extents();
328 if (!C2C) {
329 detail::array(extents).back() = detail::array(extents).back() / 2 + 1;
330 }
331 return ddc::DiscreteDomain<DDimFx...>(ddc::DiscreteDomain<DDimFx>(
332 ddc::DiscreteElement<DDimFx>(0),
333 ddc::DiscreteVector<DDimFx>(get<DDimX>(extents)))...);
334}
335
336/**
337 * @brief A structure embedding the configuration of the exposed FFT function with the type of normalization.
338 *
339 * @see fft, ifft
340 */
341struct kwArgs_fft
342{
344 normalization; ///< Enum member to identify the type of normalization performed
345};
346
347/**
348 * @brief Perform a direct Fast Fourier Transform.
349 *
350 * Compute the discrete Fourier transform of a function using the specialized implementation for the Kokkos::ExecutionSpace
351 * of the FFT algorithm.
352 *
353 * @tparam Tin The type of the input elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
354 * @tparam Tout The type of the output elements (Kokkos::complex<float> or Kokkos::complex<double>).
355 * @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
356 * @tparam DDimX... The parameter pack of the original discrete dimensions.
357 * @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the FFT is performed. It determines which specialized
358 * backend is used (ie. fftw, cuFFT...).
359 * @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
360 * @tparam LayoutIn The layout of the Chunkspan representing the input discrete function.
361 * @tparam LayoutOut The layout of the Chunkspan representing the output discrete function.
362 *
363 * @param exec_space The Kokkos::ExecutionSpace on which the FFT is performed.
364 * @param out The output discrete function, represented as a ChunkSpan storing values on a spectral mesh.
365 * @param in The input discrete function, represented as a ChunkSpan storing values on a mesh.
366 * @param kwargs The kwArgs_fft configuring the FFT.
367 */
368template <
369 typename Tin,
370 typename Tout,
371 typename... DDimFx,
372 typename... DDimX,
373 typename ExecSpace,
374 typename MemorySpace,
375 typename LayoutIn,
376 typename LayoutOut>
377void fft(
378 ExecSpace const& exec_space,
379 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimFx...>, LayoutOut, MemorySpace> out,
380 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimX...>, LayoutIn, MemorySpace> in,
382{
383 static_assert(
384 std::is_same_v<LayoutIn, Kokkos::layout_right>
385 && std::is_same_v<LayoutOut, Kokkos::layout_right>,
386 "Layouts must be right-handed");
387 static_assert(
388 (is_uniform_point_sampling_v<DDimX> && ...),
389 "DDimX dimensions should derive from UniformPointSampling");
390 static_assert(
391 (is_periodic_sampling_v<DDimFx> && ...),
392 "DDimFx dimensions should derive from PeriodicPointSampling");
393
394 ddc::detail::fft::
395 impl(exec_space, in, out, {ddc::FFT_Direction::FORWARD, kwargs.normalization});
396}
397
398/**
399 * @brief Perform an inverse Fast Fourier Transform.
400 *
401 * Compute the inverse discrete Fourier transform of a spectral function using the specialized implementation for the Kokkos::ExecutionSpace
402 * of the iFFT algorithm.
403 *
404 * @warning C2R iFFT does NOT preserve input.
405 *
406 * @tparam Tin The type of the input elements (Kokkos::complex<float> or Kokkos::complex<double>).
407 * @tparam Tout The type of the output elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
408 * @tparam DDimX... The parameter pack of the original discrete dimensions.
409 * @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
410 * @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the iFFT is performed. It determines which specialized
411 * backend is used (ie. fftw, cuFFT...).
412 * @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
413 * @tparam LayoutIn The layout of the Chunkspan representing the input discrete function.
414 * @tparam LayoutOut The layout of the Chunkspan representing the output discrete function.
415 *
416 * @param exec_space The Kokkos::ExecutionSpace on which the iFFT is performed.
417 * @param out The output discrete function, represented as a ChunkSpan storing values on a mesh.
418 * @param in The input discrete function, represented as a ChunkSpan storing values on a spectral mesh.
419 * @param kwargs The kwArgs_fft configuring the iFFT.
420 */
421template <
422 typename Tin,
423 typename Tout,
424 typename... DDimX,
425 typename... DDimFx,
426 typename ExecSpace,
427 typename MemorySpace,
428 typename LayoutIn,
429 typename LayoutOut>
430void ifft(
431 ExecSpace const& exec_space,
432 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimX...>, LayoutOut, MemorySpace> out,
433 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimFx...>, LayoutIn, MemorySpace> in,
435{
436 static_assert(
437 std::is_same_v<LayoutIn, Kokkos::layout_right>
438 && std::is_same_v<LayoutOut, Kokkos::layout_right>,
439 "Layouts must be right-handed");
440 static_assert(
441 (is_uniform_point_sampling_v<DDimX> && ...),
442 "DDimX dimensions should derive from UniformPointSampling");
443 static_assert(
444 (is_periodic_sampling_v<DDimFx> && ...),
445 "DDimFx dimensions should derive from PeriodicPointSampling");
446
447 ddc::detail::fft::
448 impl(exec_space, in, out, {ddc::FFT_Direction::BACKWARD, kwargs.normalization});
449}
450
451} // namespace ddc
friend class DiscreteDomain
KOKKOS_FUNCTION constexpr bool operator!=(DiscreteVector< OTags... > const &rhs) const noexcept
The top-level namespace of DDC.
ddc::FFT_Normalization normalization
Enum member to identify the type of normalization performed.
Definition fft.hpp:344
void ifft(ExecSpace const &exec_space, ddc::ChunkSpan< Tout, ddc::DiscreteDomain< DDimX... >, LayoutOut, MemorySpace > out, ddc::ChunkSpan< Tin, ddc::DiscreteDomain< DDimFx... >, LayoutIn, MemorySpace > in, ddc::kwArgs_fft kwargs={ddc::FFT_Normalization::OFF})
Perform an inverse Fast Fourier Transform.
Definition fft.hpp:430
void fft(ExecSpace const &exec_space, ddc::ChunkSpan< Tout, ddc::DiscreteDomain< DDimFx... >, LayoutOut, MemorySpace > out, ddc::ChunkSpan< Tin, ddc::DiscreteDomain< DDimX... >, LayoutIn, MemorySpace > in, ddc::kwArgs_fft kwargs={ddc::FFT_Normalization::OFF})
Perform a direct Fast Fourier Transform.
Definition fft.hpp:377
FFT_Normalization
A named argument to choose the type of normalization of the FFT.
Definition fft.hpp:41
@ BACKWARD
No normalization for forward FFT, multiply by 1/N for backward FFT.
@ OFF
No normalization. Un-normalized FFT is sum_j f(x_j)*e^-ikx_j.
@ ORTHO
Multiply by 1/sqrt(N)
@ FULL
Multiply by dx/sqrt(2*pi) for forward FFT and dk/sqrt(2*pi) for backward FFT.
@ FORWARD
Multiply by 1/N for forward FFT, no normalization for backward FFT.
ddc::DiscreteDomain< DDimFx... > fourier_mesh(ddc::DiscreteDomain< DDimX... > x_mesh, bool C2C)
Get the Fourier mesh.
Definition fft.hpp:319
DDimFx::template Impl< DDimFx, Kokkos::HostSpace > init_fourier_space(ddc::DiscreteDomain< DDimX > x_mesh)
Initialize a Fourier discrete dimension.
Definition fft.hpp:278
FFT_Direction
A named argument to choose the direction of the FFT.
Definition fft.hpp:31
@ BACKWARD
Backward, corresponds to inverse FFT up to normalization.
@ FORWARD
Forward, corresponds to direct FFT up to normalization.
A structure embedding the configuration of the exposed FFT function with the type of normalization.
Definition fft.hpp:342