DDC 0.11.0
Loading...
Searching...
No Matches
splines_linear_problem_band.cpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#include <algorithm>
6#include <cassert>
7#if !defined(NDEBUG)
8# include <cmath>
9#endif
10#include <cstddef>
11#include <stdexcept>
12#include <string>
13
14#include <Kokkos_Core.hpp>
15
16#if __has_include(<mkl_lapacke.h>)
17# include <mkl_lapacke.h>
18#else
19# include <lapacke.h>
20#endif
21
22#include <KokkosBatched_Gbtrs.hpp>
23#include <KokkosBatched_Util.hpp>
24
27
28namespace ddc::detail {
29
30template <class ExecSpace>
31SplinesLinearProblemBand<ExecSpace>::SplinesLinearProblemBand(
32 std::size_t const mat_size,
33 std::size_t const kl,
34 std::size_t const ku)
35 : SplinesLinearProblem<ExecSpace>(mat_size)
36 , m_kl(kl)
37 , m_ku(ku)
38 /*
39 * The matrix itself stored in band format requires a (kl + ku + 1)*mat_size
40 * allocation, but the LU-factorization requires an additional kl*mat_size block
41 */
42 , m_q("q", 2 * kl + ku + 1, mat_size)
43 , m_ipiv("ipiv", mat_size)
44{
45 assert(m_kl <= mat_size);
46 assert(m_ku <= mat_size);
47
48 Kokkos::deep_copy(m_q.view_host(), 0.);
49}
50
// Defaulted destructor: the class holds no resources beyond its members
// (m_q, m_ipiv), which are destroyed together with the object.
51template <class ExecSpace>
52SplinesLinearProblemBand<ExecSpace>::~SplinesLinearProblemBand() = default;
53
54template <class ExecSpace>
55std::size_t SplinesLinearProblemBand<ExecSpace>::band_storage_row_index(
56 std::size_t const i,
57 std::size_t const j) const
58{
59 return m_kl + m_ku + i - j;
60}
61
62template <class ExecSpace>
63double SplinesLinearProblemBand<ExecSpace>::get_element(std::size_t const i, std::size_t const j)
64 const
65{
66 assert(i < size());
67 assert(j < size());
68 /*
69 * The "row index" of the band format storage identify the (sub/super)-diagonal
70 * while the column index is actually the column index of the matrix. Two layouts
71 * are supported by LAPACKE. The m_kl first lines are irrelevant for the storage of
72 * the matrix itself but required for the storage of its LU factorization.
73 */
74 if (i >= std::
75 max(static_cast<std::ptrdiff_t>(0),
76 static_cast<std::ptrdiff_t>(j) - static_cast<std::ptrdiff_t>(m_ku))
77 && i < std::min(size(), j + m_kl + 1)) {
78 return m_q.view_host()(band_storage_row_index(i, j), j);
79 }
80
81 return 0.0;
82}
83
84template <class ExecSpace>
85void SplinesLinearProblemBand<ExecSpace>::set_element(
86 std::size_t const i,
87 std::size_t const j,
88 double const aij)
89{
90 assert(i < size());
91 assert(j < size());
92 /*
93 * The "row index" of the band format storage identify the (sub/super)-diagonal
94 * while the column index is actually the column index of the matrix. Two layouts
95 * are supported by LAPACKE. The m_kl first lines are irrelevant for the storage of
96 * the matrix itself but required for the storage of its LU factorization.
97 */
98 if (i >= std::
99 max(static_cast<std::ptrdiff_t>(0),
100 static_cast<std::ptrdiff_t>(j) - static_cast<std::ptrdiff_t>(m_ku))
101 && i < std::min(size(), j + m_kl + 1)) {
102 m_q.view_host()(band_storage_row_index(i, j), j) = aij;
103 } else {
104 assert(std::fabs(aij) < 1e-15);
105 }
106}
107
108template <class ExecSpace>
109void SplinesLinearProblemBand<ExecSpace>::setup_solver()
110{
111 int const info = LAPACKE_dgbtrf(
112 LAPACK_ROW_MAJOR,
113 size(),
114 size(),
115 m_kl,
116 m_ku,
117 m_q.view_host().data(),
118 m_q.view_host().stride(
119 0), // m_q.view_host().stride(0) if LAPACK_ROW_MAJOR, m_q.view_host().stride(1) if LAPACK_COL_MAJOR
120 m_ipiv.view_host().data());
121 if (info != 0) {
122 throw std::runtime_error("LAPACKE_dgbtrf failed with error code " + std::to_string(info));
123 }
124
125 // Convert 1-based index to 0-based index
126 for (std::size_t i = 0; i < size(); ++i) {
127 m_ipiv.view_host()(i) -= 1;
128 }
129
130 // Push on device
131 m_q.modify_host();
132 m_q.sync_device();
133 m_ipiv.modify_host();
134 m_ipiv.sync_device();
135}
136
137template <class ExecSpace>
138void SplinesLinearProblemBand<ExecSpace>::solve(MultiRHS const b, bool const transpose) const
139{
140 assert(b.extent(0) == size());
141
142 std::size_t const kl_proxy = m_kl;
143 std::size_t const ku_proxy = m_ku;
144 auto q_device = m_q.view_device();
145 auto ipiv_device = m_ipiv.view_device();
146 Kokkos::RangePolicy<ExecSpace> const policy(0, b.extent(1));
147 if (transpose) {
148 Kokkos::parallel_for(
149 "gbtrs",
150 policy,
151 KOKKOS_LAMBDA(int const i) {
152 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
153 KokkosBatched::SerialGbtrs<
154 KokkosBatched::Trans::Transpose,
155 KokkosBatched::Algo::Gbtrs::Unblocked>::
156 invoke(q_device, ipiv_device, sub_b, kl_proxy, ku_proxy);
157 });
158 } else {
159 Kokkos::parallel_for(
160 "gbtrs",
161 policy,
162 KOKKOS_LAMBDA(int const i) {
163 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
164 KokkosBatched::SerialGbtrs<
165 KokkosBatched::Trans::NoTranspose,
166 KokkosBatched::Algo::Gbtrs::Unblocked>::
167 invoke(q_device, ipiv_device, sub_b, kl_proxy, ku_proxy);
168 });
169 }
170}
171
// Explicit instantiations, one per Kokkos execution space enabled in this build.
#ifdef KOKKOS_ENABLE_SERIAL
template class SplinesLinearProblemBand<Kokkos::Serial>;
#endif
#ifdef KOKKOS_ENABLE_OPENMP
template class SplinesLinearProblemBand<Kokkos::OpenMP>;
#endif
#ifdef KOKKOS_ENABLE_CUDA
template class SplinesLinearProblemBand<Kokkos::Cuda>;
#endif
#ifdef KOKKOS_ENABLE_HIP
template class SplinesLinearProblemBand<Kokkos::HIP>;
#endif
#ifdef KOKKOS_ENABLE_SYCL
template class SplinesLinearProblemBand<Kokkos::SYCL>;
#endif
187
188} // namespace ddc::detail
The top-level namespace of DDC.