14#include <Kokkos_Core.hpp>
15#include <Kokkos_DualView.hpp>
17#if __has_include(<mkl_lapacke.h>)
18#include <mkl_lapacke.h>
23#include <KokkosBatched_Gbtrs.hpp>
24#include <KokkosBatched_Util.hpp>
26#include "splines_linear_problem.hpp"
28namespace ddc::detail {
31
32
33
34
35
36
37
38
39
40
41
42
43
/**
 * Linear problem derived publicly from SplinesLinearProblem<ExecSpace>.
 *
 * The template parameter ExecSpace supplies the memory_space of the two
 * Kokkos::DualView members declared below and, in solve(), the execution
 * space of the Kokkos::RangePolicy.
 *
 * NOTE(review): this listing appears to have lost lines (braces, access
 * specifiers, some member declarations) — confirm against the full file.
 */
44template <
class ExecSpace>
45class SplinesLinearProblemBand :
public SplinesLinearProblem<ExecSpace>
// Bring the base class's MultiRHS type and size() member into scope so they
// can be named without qualification in this class template.
48 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
49 using SplinesLinearProblem<ExecSpace>::size;
// Two-dimensional DualView of double with LayoutRight in ExecSpace's memory
// space. NOTE(review): the declarator is not visible here; presumably this is
// m_q, the matrix storage used by every method of the class — confirm.
54 Kokkos::DualView<
double**, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
// One-dimensional DualView of int holding the pivot indices written by
// LAPACKE_dgbtrf in setup_solver() and read on the device in solve().
56 Kokkos::DualView<
int*,
typename ExecSpace::memory_space> m_ipiv;
60
61
62
63
64
65
/**
 * Constructor: forwards mat_size to the SplinesLinearProblem<ExecSpace> base,
 * allocates m_q with (2 * kl + ku + 1) x mat_size entries and m_ipiv with
 * mat_size entries, then zero-fills the host view of m_q.
 *
 * @param mat_size Size forwarded to the base class and used as the second
 *                 extent of m_q and the extent of m_ipiv.
 *
 * NOTE(review): kl and ku are used below but their declarations are not
 * visible in this listing; presumably they are further constructor parameters
 * (sub-/super-diagonal counts, following LAPACK band naming) that also
 * initialize m_kl and m_ku — confirm against the full file.
 */
66 explicit SplinesLinearProblemBand(
67 std::size_t
const mat_size,
70 : SplinesLinearProblem<ExecSpace>(mat_size)
74
75
76
// The 2 * kl + ku + 1 row count matches the band-storage requirement of
// LAPACK's dgbtrf (LDAB >= 2*KL+KU+1), which setup_solver() calls on m_q.
77 , m_q(
"q", 2 * kl + ku + 1, mat_size)
78 , m_ipiv(
"ipiv", mat_size)
// Debug-only sanity checks: neither band width may exceed the matrix size.
80 assert(m_kl <= mat_size);
81 assert(m_ku <= mat_size);
// Start from an all-zero matrix on the host; only the host view is filled here.
83 Kokkos::deep_copy(m_q.h_view, 0.);
87 std::size_t band_storage_row_index(std::size_t
const i, std::size_t
const j)
const
89 return m_kl + m_ku + i - j;
/**
 * Reads the matrix entry (i, j) from the host view of m_q.
 *
 * @param i Row index of the requested entry.
 * @param j Column index of the requested entry.
 * @return When the condition below holds, the value stored at
 *         m_q.h_view(band_storage_row_index(i, j), j).
 *
 * NOTE(review): the beginning of the condition (what is compared against the
 * max(...) expression) and the branch taken when it fails are not visible in
 * this listing — confirm against the full file.
 */
93 double get_element(std::size_t
const i, std::size_t
const j)
const override
98
99
100
101
102
// Lower bound max(0, j - m_ku): evaluated in std::ptrdiff_t so that j - m_ku
// can go negative instead of wrapping around as an unsigned value.
104 max(
static_cast<std::ptrdiff_t>(0),
105 static_cast<std::ptrdiff_t>(j) -
static_cast<std::ptrdiff_t>(m_ku))
// Upper bound: i must be below both size() and j + m_kl + 1.
106 && i < std::min(size(), j + m_kl + 1)) {
107 return m_q.h_view(band_storage_row_index(i, j), j);
/**
 * Writes aij as the matrix entry (i, j) into the host view of m_q.
 *
 * @param i   Row index of the entry.
 * @param j   Column index of the entry.
 * @param aij Value to store.
 *
 * NOTE(review): the beginning of the condition and the branch structure are
 * not visible in this listing. Presumably the trailing assert is the
 * alternative branch, i.e. values outside the stored band must be
 * (numerically) zero — confirm against the full file.
 */
113 void set_element(std::size_t
const i, std::size_t
const j,
double const aij)
override
118
119
120
121
122
// Lower bound max(0, j - m_ku): evaluated in std::ptrdiff_t so that j - m_ku
// can go negative instead of wrapping around as an unsigned value.
124 max(
static_cast<std::ptrdiff_t>(0),
125 static_cast<std::ptrdiff_t>(j) -
static_cast<std::ptrdiff_t>(m_ku))
// Upper bound: i must be below both size() and j + m_kl + 1.
126 && i < std::min(size(), j + m_kl + 1)) {
127 m_q.h_view(band_storage_row_index(i, j), j) = aij;
// Debug-only check that the magnitude of aij is below 1e-20.
129 assert(std::fabs(aij) < 1e-20);
134
135
136
137
/**
 * Factorizes the matrix on the host with LAPACKE_dgbtrf and publishes the
 * pivot indices to the device.
 *
 * @throws std::runtime_error carrying the LAPACKE_dgbtrf return code.
 *
 * NOTE(review): most arguments of the LAPACKE_dgbtrf call and the condition
 * guarding the throw are not visible in this listing (presumably
 * `info != 0`) — confirm against the full file.
 */
138 void setup_solver()
override
// The host view of m_ipiv receives the pivot indices computed by LAPACK.
140 int const info = LAPACKE_dgbtrf(
149 m_ipiv.h_view.data());
151 throw std::runtime_error(
152 "LAPACKE_dgbtrf failed with error code " + std::to_string(info));
// Shift every pivot down by one: LAPACK reports 1-based pivot indices.
// NOTE(review): `int i` is compared with size(), which is used as a
// std::size_t elsewhere in this class (see std::min(size(), ...)); this is a
// signed/unsigned comparison — consider std::size_t for the loop index.
156 for (
int i = 0; i < size(); ++i) {
157 m_ipiv.h_view(i) -= 1;
// Mark the host copy of the pivots as modified, then copy it to the device.
163 m_ipiv.modify_host();
164 m_ipiv.sync_device();
168
169
170
171
172
173
174
/**
 * Solves the factorized system for every column of b using
 * KokkosBatched::SerialGbtrs inside a Kokkos::parallel_for.
 *
 * @param b         Multiple right-hand sides; its first extent must equal
 *                  size(). One parallel iteration is launched per index of
 *                  its second extent.
 * @param transpose NOTE(review): the branch that uses this flag is not
 *                  visible in this listing; presumably it selects between the
 *                  Trans::Transpose and Trans::NoTranspose kernels below —
 *                  confirm against the full file.
 *
 * NOTE(review): presumably each column of b is overwritten with its
 * solution, as the subview is the only per-column argument handed to
 * SerialGbtrs — confirm against the Kokkos Kernels documentation.
 */
175 void solve(MultiRHS
const b,
bool const transpose)
const override
177 assert(b.extent(0) == size());
// Local copies of the band widths and device views: the lambdas below use
// these locals instead of reading the members through the object.
179 std::size_t
const kl_proxy = m_kl;
180 std::size_t
const ku_proxy = m_ku;
181 auto q_device = m_q.d_view;
182 auto ipiv_device = m_ipiv.d_view;
// Iteration range [0, b.extent(1)) on ExecSpace.
183 Kokkos::RangePolicy<ExecSpace>
const policy(0, b.extent(1));
// Transposed solve: SerialGbtrs with Trans::Transpose on column i of b.
185 Kokkos::parallel_for(
188 KOKKOS_LAMBDA(
const int i) {
189 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
190 KokkosBatched::SerialGbtrs<
191 KokkosBatched::Trans::Transpose,
192 KokkosBatched::Algo::Gbtrs::Unblocked>::
193 invoke(q_device, sub_b, ipiv_device, kl_proxy, ku_proxy);
// Non-transposed solve: same call with Trans::NoTranspose.
196 Kokkos::parallel_for(
199 KOKKOS_LAMBDA(
const int i) {
200 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
201 KokkosBatched::SerialGbtrs<
202 KokkosBatched::Trans::NoTranspose,
203 KokkosBatched::Algo::Gbtrs::Unblocked>::
204 invoke(q_device, sub_b, ipiv_device, kl_proxy, ku_proxy);
The top-level namespace of DDC.