DDC 0.10.0
Loading...
Searching...
No Matches
splines_linear_problem_pds_tridiag.cpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#include <cassert>
6#if !defined(NDEBUG)
7# include <cmath>
8#endif
9#include <cstddef>
10#include <stdexcept>
11#include <string>
12#include <utility>
13
14#include <Kokkos_Core.hpp>
15#include <Kokkos_DualView.hpp>
16
17#if __has_include(<mkl_lapacke.h>)
18# include <mkl_lapacke.h>
19#else
20# include <lapacke.h>
21#endif
22
23#include <KokkosBatched_Pttrs.hpp>
24#include <KokkosBatched_Util.hpp>
25
28
29namespace ddc::detail {
30
31template <class ExecSpace>
32SplinesLinearProblemPDSTridiag<ExecSpace>::SplinesLinearProblemPDSTridiag(
33 std::size_t const mat_size)
34 : SplinesLinearProblem<ExecSpace>(mat_size)
35 , m_q("q", 2, mat_size)
36{
37 Kokkos::deep_copy(m_q.view_host(), 0.);
38}
39
40template <class ExecSpace>
41SplinesLinearProblemPDSTridiag<ExecSpace>::~SplinesLinearProblemPDSTridiag() = default;
42
43template <class ExecSpace>
44double SplinesLinearProblemPDSTridiag<ExecSpace>::get_element(std::size_t i, std::size_t j) const
45{
46 assert(i < size());
47 assert(j < size());
48
49 // Indices are swapped for an element on subdiagonal
50 if (i > j) {
51 std::swap(i, j);
52 }
53
54 if (j - i < 2) {
55 return m_q.view_host()(j - i, i);
56 }
57
58 return 0.0;
59}
60
61template <class ExecSpace>
62void SplinesLinearProblemPDSTridiag<ExecSpace>::set_element(
63 std::size_t i,
64 std::size_t j,
65 double const aij)
66{
67 assert(i < size());
68 assert(j < size());
69
70 // Indices are swapped for an element on subdiagonal
71 if (i > j) {
72 std::swap(i, j);
73 }
74 if (j - i < 2) {
75 m_q.view_host()(j - i, i) = aij;
76 } else {
77 assert(std::fabs(aij) < 1e-15);
78 }
79}
80
81template <class ExecSpace>
82void SplinesLinearProblemPDSTridiag<ExecSpace>::setup_solver()
83{
84 int const info = LAPACKE_dpttrf(
85 size(),
86 m_q.view_host().data(),
87 m_q.view_host().data() + m_q.view_host().stride(0));
88 if (info != 0) {
89 throw std::runtime_error("LAPACKE_dpttrf failed with error code " + std::to_string(info));
90 }
91
92 // Push on device
93 m_q.modify_host();
94 m_q.sync_device();
95}
96
97template <class ExecSpace>
98void SplinesLinearProblemPDSTridiag<ExecSpace>::solve(MultiRHS const b, bool const) const
99{
100 assert(b.extent(0) == size());
101 auto q_device = m_q.view_device();
102 auto d = Kokkos::subview(q_device, 0, Kokkos::ALL);
103 auto e = Kokkos::subview(q_device, 1, Kokkos::pair<int, int>(0, q_device.extent_int(1) - 1));
104 Kokkos::RangePolicy<ExecSpace> const policy(0, b.extent(1));
105 Kokkos::parallel_for(
106 "pttrs",
107 policy,
108 KOKKOS_LAMBDA(int const i) {
109 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
110 KokkosBatched::SerialPttrs<
111 KokkosBatched::Uplo::Lower,
112 KokkosBatched::Algo::Pttrs::Unblocked>::invoke(d, e, sub_b);
113 });
114}
115
// Explicit instantiations: one per Kokkos execution space whose
// KOKKOS_ENABLE_* macro is defined.
#ifdef KOKKOS_ENABLE_SERIAL
template class SplinesLinearProblemPDSTridiag<Kokkos::Serial>;
#endif

#ifdef KOKKOS_ENABLE_OPENMP
template class SplinesLinearProblemPDSTridiag<Kokkos::OpenMP>;
#endif

#ifdef KOKKOS_ENABLE_CUDA
template class SplinesLinearProblemPDSTridiag<Kokkos::Cuda>;
#endif

#ifdef KOKKOS_ENABLE_HIP
template class SplinesLinearProblemPDSTridiag<Kokkos::HIP>;
#endif

#ifdef KOKKOS_ENABLE_SYCL
template class SplinesLinearProblemPDSTridiag<Kokkos::SYCL>;
#endif
131
132} // namespace ddc::detail
The top-level namespace of DDC.