DDC 0.10.0
Loading...
Searching...
No Matches
fft.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <cassert>
8#include <stdexcept>
9#include <type_traits>
10#include <utility>
11
12#include <ddc/ddc.hpp>
13
14#include <KokkosFFT.hpp>
15#include <Kokkos_Core.hpp>
16
17namespace ddc {
18
/**
 * @brief A templated tag representing a continuous dimension in the Fourier space associated to the original continuous dimension.
 *
 * @tparam Dim The tag representing the original continuous dimension.
 */
template <typename Dim>
struct Fourier;
26
/**
 * @brief A named argument to choose the direction of the FFT.
 *
 * fft() performs a FORWARD transform and ifft() a BACKWARD one.
 *
 * @see fft, ifft
 */
enum class FFT_Direction {
    FORWARD, ///< Forward, corresponds to direct FFT up to normalization
    BACKWARD ///< Backward, corresponds to inverse FFT up to normalization
};
36
/**
 * @brief A named argument to choose the type of normalization of the FFT.
 *
 * @see kwArgs_fft, fft, ifft
 */
enum class FFT_Normalization {
    OFF, ///< No normalization. Un-normalized FFT is sum_j f(x_j)*e^-ikx_j
    FORWARD, ///< Multiply by 1/N for forward FFT, no normalization for backward FFT
    BACKWARD, ///< No normalization for forward FFT, multiply by 1/N for backward FFT
    ORTHO, ///< Multiply by 1/sqrt(N)
    FULL /**<
          * Multiply by dx/sqrt(2*pi) for forward FFT and dk/sqrt(2*pi) for backward
          * FFT. It is aligned with the usual definition of the (continuous) Fourier transform
          * 1/sqrt(2*pi)*int f(x)*e^-ikx*dx, and thus may be relevant for spectral analysis applications.
          */
};
53
54} // namespace ddc
55
56namespace ddc::detail::fft {
57
58template <typename T>
59struct RealType
60{
61 using type = T;
62};
63
64template <typename T>
65struct RealType<Kokkos::complex<T>>
66{
67 using type = T;
68};
69
70template <typename T>
71using real_type_t = typename RealType<T>::type;
72
73// is_complex : trait to determine if type is Kokkos::complex<something>
74template <typename T>
75struct is_complex : std::false_type
76{
77};
78
79template <typename T>
80struct is_complex<Kokkos::complex<T>> : std::true_type
81{
82};
83
84template <typename T>
85constexpr bool is_complex_v = is_complex<T>::value;
86
87/*
88 * @brief A structure embedding the configuration of the impl FFT function: direction and type of normalization.
89 *
90 * @see FFT_impl
91 */
92struct KwArgsImpl
93{
95 direction; // Only effective for C2C transform and for normalization BACKWARD and FORWARD
96 ddc::FFT_Normalization normalization;
97};
98
99template <typename... DDimX>
100KokkosFFT::axis_type<sizeof...(DDimX)> axes()
101{
102 return KokkosFFT::axis_type<sizeof...(DDimX)> {
103 static_cast<int>(ddc::type_seq_rank_v<DDimX, ddc::detail::TypeSeq<DDimX...>>)...};
104}
105
106inline KokkosFFT::Normalization ddc_fft_normalization_to_kokkos_fft(
107 FFT_Normalization const ddc_fft_normalization)
108{
109 if (ddc_fft_normalization == ddc::FFT_Normalization::OFF
110 || ddc_fft_normalization == ddc::FFT_Normalization::FULL) {
111 return KokkosFFT::Normalization::none;
112 }
113
114 if (ddc_fft_normalization == ddc::FFT_Normalization::FORWARD) {
115 return KokkosFFT::Normalization::forward;
116 }
117
118 if (ddc_fft_normalization == ddc::FFT_Normalization::BACKWARD) {
119 return KokkosFFT::Normalization::backward;
120 }
121
122 if (ddc_fft_normalization == ddc::FFT_Normalization::ORTHO) {
123 return KokkosFFT::Normalization::ortho;
124 }
125
126 throw std::runtime_error("ddc::FFT_Normalization not handled");
127}
128
129template <
130 typename ExecSpace,
131 typename ElementType,
132 typename DDom,
133 typename Layout,
134 typename MemorySpace,
135 typename T>
136void rescale(
137 ExecSpace const& exec_space,
138 ddc::ChunkSpan<ElementType, DDom, Layout, MemorySpace> const& chunk_span,
139 T const& value)
140{
141 ddc::parallel_for_each(
142 "ddc_fft_normalization",
143 exec_space,
144 chunk_span.domain(),
145 KOKKOS_LAMBDA(typename DDom::discrete_element_type const i) {
146 chunk_span(i) *= value;
147 });
148}
149
150template <class DDim>
151Real forward_full_norm_coef(DiscreteDomain<DDim> const& ddom) noexcept
152{
153 return rlength(ddom) / Kokkos::sqrt(2 * Kokkos::numbers::pi_v<Real>)
154 / (ddom.extents() - 1).value();
155}
156
157template <class DDim>
158Real backward_full_norm_coef(DiscreteDomain<DDim> const& ddom) noexcept
159{
160 return 1 / (forward_full_norm_coef(ddom) * ddom.extents().value());
161}
162
163/// @brief Core internal function to perform the FFT.
164template <
165 typename Tin,
166 typename Tout,
167 typename ExecSpace,
168 typename MemorySpace,
169 typename LayoutIn,
170 typename LayoutOut,
171 typename... DDimIn,
172 typename... DDimOut>
173void impl(
174 ExecSpace const& exec_space,
175 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimIn...>, LayoutIn, MemorySpace> const& in,
176 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimOut...>, LayoutOut, MemorySpace> const& out,
177 KwArgsImpl const& kwargs)
178{
179 static_assert(
180 std::is_same_v<real_type_t<Tin>, float> || std::is_same_v<real_type_t<Tin>, double>,
181 "Base type of Tin (and Tout) must be float or double.");
182 static_assert(
183 std::is_same_v<real_type_t<Tin>, real_type_t<Tout>>,
184 "Types Tin and Tout must be based on same type (float or double)");
185 static_assert(
186 Kokkos::SpaceAccessibility<ExecSpace, MemorySpace>::accessible,
187 "MemorySpace has to be accessible for ExecutionSpace.");
188
189 Kokkos::View<
190 ddc::detail::mdspan_to_kokkos_element_t<Tin, sizeof...(DDimIn)>,
191 ddc::detail::mdspan_to_kokkos_layout_t<LayoutIn>,
192 MemorySpace> const in_view
193 = in.allocation_kokkos_view();
194 Kokkos::View<
195 ddc::detail::mdspan_to_kokkos_element_t<Tout, sizeof...(DDimIn)>,
196 ddc::detail::mdspan_to_kokkos_layout_t<LayoutOut>,
197 MemorySpace> const out_view
198 = out.allocation_kokkos_view();
199 KokkosFFT::Normalization const kokkos_fft_normalization
200 = ddc_fft_normalization_to_kokkos_fft(kwargs.normalization);
201
202 // C2C
203 if constexpr (std::is_same_v<Tin, Tout>) {
204 if (kwargs.direction == ddc::FFT_Direction::FORWARD) {
205 KokkosFFT::
206 fftn(exec_space,
207 in_view,
208 out_view,
209 axes<DDimIn...>(),
210 kokkos_fft_normalization);
211 } else {
212 KokkosFFT::
213 ifftn(exec_space,
214 in_view,
215 out_view,
216 axes<DDimIn...>(),
217 kokkos_fft_normalization);
218 }
219 // R2C & C2R
220 } else {
221 if constexpr (is_complex_v<Tout>) {
222 assert(kwargs.direction == ddc::FFT_Direction::FORWARD);
223 KokkosFFT::
224 rfftn(exec_space,
225 in_view,
226 out_view,
227 axes<DDimIn...>(),
228 kokkos_fft_normalization);
229 } else {
230 assert(kwargs.direction == ddc::FFT_Direction::BACKWARD);
231 KokkosFFT::
232 irfftn(exec_space,
233 in_view,
234 out_view,
235 axes<DDimIn...>(),
236 kokkos_fft_normalization);
237 }
238 }
239
240 // The FULL normalization is mesh-dependant and thus handled by DDC
241 if (kwargs.normalization == ddc::FFT_Normalization::FULL) {
242 Real norm_coef;
243 if (kwargs.direction == ddc::FFT_Direction::FORWARD) {
244 DiscreteDomain<DDimIn...> const ddom_in = in.domain();
245 norm_coef = (forward_full_norm_coef(DiscreteDomain<DDimIn>(ddom_in)) * ...);
246 } else {
247 DiscreteDomain<DDimOut...> const ddom_out = out.domain();
248 norm_coef = (backward_full_norm_coef(DiscreteDomain<DDimOut>(ddom_out)) * ...);
249 }
250
251 rescale(exec_space, out, static_cast<real_type_t<Tout>>(norm_coef));
252 }
253}
254
255} // namespace ddc::detail::fft
256
257namespace ddc {
258
259/**
260 * @brief Initialize a Fourier discrete dimension.
261 *
262 * Initialize the (1D) discrete space representing the Fourier discrete dimension associated
263 * to the (1D) mesh passed as argument. It is a N-periodic PeriodicSampling with a periodic window of width 2*pi/dx.
264 *
265 * This value comes from the Nyquist-Shannon theorem: the period of the spectral domain is N*dk = 2*pi/dx.
266 * Adding to this the relations dx = (xmax-xmin)/(N-1), and dk = (kmax-kmin)/(N-1), we get kmax-kmin = 2*pi*(N-1)^2/N/(xmax-xmin),
267 * which is used in the implementation (xmax, xmin, kmin and kmax are the centers of lower and upper cells inside a single period of the meshes).
268 *
269 * @tparam DDimFx A PeriodicSampling representing the Fourier discrete dimension.
270 * @tparam DDimX The type of the original discrete dimension.
271 *
272 * @param x_mesh The DiscreteDomain representing the (1D) original mesh.
273 *
274 * @return The initialized Impl representing the discrete Fourier space.
275 *
276 * @see PeriodicSampling
277 */
278template <typename DDimFx, typename DDimX>
279typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
280 ddc::DiscreteDomain<DDimX> x_mesh)
281{
282 static_assert(
283 is_uniform_point_sampling_v<DDimX>,
284 "DDimX dimension must derive from UniformPointSampling");
285 static_assert(
286 is_periodic_sampling_v<DDimFx>,
287 "DDimFx dimension must derive from PeriodicSampling");
288 using CDimFx = typename DDimFx::continuous_dimension_type;
289 using CDimX = typename DDimX::continuous_dimension_type;
290 static_assert(
291 std::is_same_v<CDimFx, ddc::Fourier<CDimX>>,
292 "DDimX and DDimFx dimensions must be defined over the same continuous dimension");
293
294 DiscreteVectorElement const nx = get<DDimX>(x_mesh.extents());
295 double const lx = ddc::rlength(x_mesh);
296 auto [impl, ddom] = DDimFx::template init<DDimFx>(
297 ddc::Coordinate<CDimFx>(0),
298 ddc::Coordinate<CDimFx>(2 * (nx - 1) * (nx - 1) / (nx * lx) * Kokkos::numbers::pi),
299 ddc::DiscreteVector<DDimFx>(nx),
300 ddc::DiscreteVector<DDimFx>(nx));
301 return std::move(impl);
302}
303
304/**
305 * @brief Get the Fourier mesh.
306 *
307 * Compute the Fourier (or spectral) mesh on which the Discrete Fourier Transform of a
308 * discrete function is defined.
309 *
310 * @param x_mesh The DiscreteDomain representing the original mesh.
311 * @param C2C A flag indicating if a complex-to-complex DFT is going to be performed. Indeed,
312 * in this case the two meshes have same number of points, whereas for real-to-complex
313 * or complex-to-real DFT, each complex value of the Fourier-transformed function contains twice more
314 * information, and thus only half (actually Nx*Ny*(Nz/2+1) for 3D R2C FFT to take in account mode 0)
315 * values are needed (cf. DFT conjugate symmetry property for more information about this).
316 *
317 * @return The domain representing the Fourier mesh.
318 */
319template <typename... DDimFx, typename... DDimX>
320ddc::DiscreteDomain<DDimFx...> fourier_mesh(ddc::DiscreteDomain<DDimX...> x_mesh, bool C2C)
321{
322 static_assert(
323 (is_uniform_point_sampling_v<DDimX> && ...),
324 "DDimX dimensions should derive from UniformPointSampling");
325 static_assert(
326 (is_periodic_sampling_v<DDimFx> && ...),
327 "DDimFx dimensions should derive from PeriodicPointSampling");
328 ddc::DiscreteVector<DDimX...> extents = x_mesh.extents();
329 if (!C2C) {
330 detail::array(extents).back() = detail::array(extents).back() / 2 + 1;
331 }
332 return ddc::DiscreteDomain<DDimFx...>(ddc::DiscreteDomain<DDimFx>(
333 ddc::DiscreteElement<DDimFx>(0),
334 ddc::DiscreteVector<DDimFx>(get<DDimX>(extents)))...);
335}
336
337/**
338 * @brief A structure embedding the configuration of the exposed FFT function with the type of normalization.
339 *
340 * @see fft, ifft
341 */
342struct kwArgs_fft
343{
345 normalization; ///< Enum member to identify the type of normalization performed
346};
347
348/**
349 * @brief Perform a direct Fast Fourier Transform.
350 *
351 * Compute the discrete Fourier transform of a function using the specialized implementation for the Kokkos::ExecutionSpace
352 * of the FFT algorithm.
353 *
354 * @tparam Tin The type of the input elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
355 * @tparam Tout The type of the output elements (Kokkos::complex<float> or Kokkos::complex<double>).
356 * @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
357 * @tparam DDimX... The parameter pack of the original discrete dimensions.
358 * @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the FFT is performed. It determines which specialized
359 * backend is used (ie. fftw, cuFFT...).
360 * @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
361 * @tparam LayoutIn The layout of the Chunkspan representing the input discrete function.
362 * @tparam LayoutOut The layout of the Chunkspan representing the output discrete function.
363 *
364 * @param exec_space The Kokkos::ExecutionSpace on which the FFT is performed.
365 * @param out The output discrete function, represented as a ChunkSpan storing values on a spectral mesh.
366 * @param in The input discrete function, represented as a ChunkSpan storing values on a mesh.
367 * @param kwargs The kwArgs_fft configuring the FFT.
368 */
369template <
370 typename Tin,
371 typename Tout,
372 typename... DDimFx,
373 typename... DDimX,
374 typename ExecSpace,
375 typename MemorySpace,
376 typename LayoutIn,
377 typename LayoutOut>
378void fft(
379 ExecSpace const& exec_space,
380 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimFx...>, LayoutOut, MemorySpace> out,
381 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimX...>, LayoutIn, MemorySpace> in,
383{
384 static_assert(
385 std::is_same_v<LayoutIn, Kokkos::layout_right>
386 && std::is_same_v<LayoutOut, Kokkos::layout_right>,
387 "Layouts must be right-handed");
388 static_assert(
389 (is_uniform_point_sampling_v<DDimX> && ...),
390 "DDimX dimensions should derive from UniformPointSampling");
391 static_assert(
392 (is_periodic_sampling_v<DDimFx> && ...),
393 "DDimFx dimensions should derive from PeriodicPointSampling");
394
395 ddc::detail::fft::
396 impl(exec_space, in, out, {ddc::FFT_Direction::FORWARD, kwargs.normalization});
397}
398
399/**
400 * @brief Perform an inverse Fast Fourier Transform.
401 *
402 * Compute the inverse discrete Fourier transform of a spectral function using the specialized implementation for the Kokkos::ExecutionSpace
403 * of the iFFT algorithm.
404 *
405 * @warning C2R iFFT does NOT preserve input.
406 *
407 * @tparam Tin The type of the input elements (Kokkos::complex<float> or Kokkos::complex<double>).
408 * @tparam Tout The type of the output elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
409 * @tparam DDimX... The parameter pack of the original discrete dimensions.
410 * @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
411 * @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the iFFT is performed. It determines which specialized
412 * backend is used (ie. fftw, cuFFT...).
413 * @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
414 * @tparam LayoutIn The layout of the Chunkspan representing the input discrete function.
415 * @tparam LayoutOut The layout of the Chunkspan representing the output discrete function.
416 *
417 * @param exec_space The Kokkos::ExecutionSpace on which the iFFT is performed.
418 * @param out The output discrete function, represented as a ChunkSpan storing values on a mesh.
419 * @param in The input discrete function, represented as a ChunkSpan storing values on a spectral mesh.
420 * @param kwargs The kwArgs_fft configuring the iFFT.
421 */
422template <
423 typename Tin,
424 typename Tout,
425 typename... DDimX,
426 typename... DDimFx,
427 typename ExecSpace,
428 typename MemorySpace,
429 typename LayoutIn,
430 typename LayoutOut>
431void ifft(
432 ExecSpace const& exec_space,
433 ddc::ChunkSpan<Tout, ddc::DiscreteDomain<DDimX...>, LayoutOut, MemorySpace> out,
434 ddc::ChunkSpan<Tin, ddc::DiscreteDomain<DDimFx...>, LayoutIn, MemorySpace> in,
436{
437 static_assert(
438 std::is_same_v<LayoutIn, Kokkos::layout_right>
439 && std::is_same_v<LayoutOut, Kokkos::layout_right>,
440 "Layouts must be right-handed");
441 static_assert(
442 (is_uniform_point_sampling_v<DDimX> && ...),
443 "DDimX dimensions should derive from UniformPointSampling");
444 static_assert(
445 (is_periodic_sampling_v<DDimFx> && ...),
446 "DDimFx dimensions should derive from PeriodicPointSampling");
447
448 ddc::detail::fft::
449 impl(exec_space, in, out, {ddc::FFT_Direction::BACKWARD, kwargs.normalization});
450}
451
452} // namespace ddc
friend class ChunkSpan
friend class DiscreteDomain
KOKKOS_FUNCTION constexpr bool operator!=(DiscreteVector< OTags... > const &rhs) const noexcept
The top-level namespace of DDC.
ddc::FFT_Normalization normalization
Enum member to identify the type of normalization performed.
Definition fft.hpp:345
void ifft(ExecSpace const &exec_space, ddc::ChunkSpan< Tout, ddc::DiscreteDomain< DDimX... >, LayoutOut, MemorySpace > out, ddc::ChunkSpan< Tin, ddc::DiscreteDomain< DDimFx... >, LayoutIn, MemorySpace > in, ddc::kwArgs_fft kwargs={ddc::FFT_Normalization::OFF})
Perform an inverse Fast Fourier Transform.
Definition fft.hpp:431
void fft(ExecSpace const &exec_space, ddc::ChunkSpan< Tout, ddc::DiscreteDomain< DDimFx... >, LayoutOut, MemorySpace > out, ddc::ChunkSpan< Tin, ddc::DiscreteDomain< DDimX... >, LayoutIn, MemorySpace > in, ddc::kwArgs_fft kwargs={ddc::FFT_Normalization::OFF})
Perform a direct Fast Fourier Transform.
Definition fft.hpp:378
FFT_Normalization
A named argument to choose the type of normalization of the FFT.
Definition fft.hpp:42
@ BACKWARD
No normalization for forward FFT, multiply by 1/N for backward FFT.
@ OFF
No normalization. Un-normalized FFT is sum_j f(x_j)*e^-ikx_j.
@ ORTHO
Multiply by 1/sqrt(N)
@ FULL
Multiply by dx/sqrt(2*pi) for forward FFT and dk/sqrt(2*pi) for backward FFT.
@ FORWARD
Multiply by 1/N for forward FFT, no normalization for backward FFT.
ddc::DiscreteDomain< DDimFx... > fourier_mesh(ddc::DiscreteDomain< DDimX... > x_mesh, bool C2C)
Get the Fourier mesh.
Definition fft.hpp:320
DDimFx::template Impl< DDimFx, Kokkos::HostSpace > init_fourier_space(ddc::DiscreteDomain< DDimX > x_mesh)
Initialize a Fourier discrete dimension.
Definition fft.hpp:279
FFT_Direction
A named argument to choose the direction of the FFT.
Definition fft.hpp:32
@ BACKWARD
Backward, corresponds to inverse FFT up to normalization.
@ FORWARD
Forward, corresponds to direct FFT up to normalization.
A structure embedding the configuration of the exposed FFT function with the type of normalization.
Definition fft.hpp:343