16#include <Kokkos_Core.hpp>
17#include <Kokkos_DualView.hpp>
19#if __has_include(<mkl_lapacke.h>)
20# include <mkl_lapacke.h>
25#include <KokkosBatched_Util.hpp>
27#include "kokkos-kernels-ext/KokkosBatched_Gbtrs.hpp"
29#include "splines_linear_problem.hpp"
31namespace ddc::detail {
34
35
36
37
38
39
40
41
42
43
44
45
46
// Linear problem backed by a band-storage matrix; class template over ExecSpace that derives
// publicly from SplinesLinearProblem<ExecSpace>.
// NOTE(review): the numeric prefixes on these lines ("47", "48", ...) skip values, so some
// original lines (braces, access specifiers, possibly other members) are presumably missing
// from this listing -- confirm against the real header before relying on it.
47template <
class ExecSpace>
48class SplinesLinearProblemBand :
public SplinesLinearProblem<ExecSpace>
// Bring the base class's MultiRHS type and size() into scope so they can be used unqualified.
51 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
52 using SplinesLinearProblem<ExecSpace>::size;
// Two-dimensional, LayoutRight DualView of doubles in ExecSpace's memory space.
// NOTE(review): no declarator name is visible for this member; presumably it is m_q, which the
// constructor builds with extents (2 * kl + ku + 1, mat_size) -- confirm.
57 Kokkos::DualView<
double**, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
// One-dimensional DualView of ints; setup_solver() hands its host data to LAPACKE_dgbtrf
// and solve() passes its device view to the batched Gbtrs kernel.
59 Kokkos::DualView<
int*,
typename ExecSpace::memory_space> m_ipiv;
63
64
65
66
67
68
// Constructor: forwards mat_size to the SplinesLinearProblem base, allocates m_q with
// 2 * kl + ku + 1 rows and mat_size columns, allocates m_ipiv with mat_size entries, and
// zero-fills the host view of m_q.
// NOTE(review): the kl / ku parameters and the m_kl / m_ku initializers are used below but are
// not visible in this listing; presumably they sit on the skipped lines -- confirm.
69 explicit SplinesLinearProblemBand(
70 std::size_t
const mat_size,
73 : SplinesLinearProblem<ExecSpace>(mat_size)
77
78
79
// NOTE(review): 2 * kl + ku + 1 rows matches the band-storage size LAPACK's dgbtrf documents
// (kl extra rows for the factorization) -- confirm against the LAPACK reference.
80 , m_q(
"q", 2 * kl + ku + 1, mat_size)
81 , m_ipiv(
"ipiv", mat_size)
// Debug-only sanity checks: neither bandwidth may exceed the matrix size.
83 assert(m_kl <= mat_size);
84 assert(m_ku <= mat_size);
// Start from an all-zero matrix on the host side.
86 Kokkos::deep_copy(m_q.view_host(), 0.);
// Returns the row of m_q used for matrix entry (i, j): m_kl + m_ku + i - j.
// get_element() and set_element() pair this row with column j when indexing m_q.
// The arithmetic is unsigned (std::size_t); the sum is evaluated left to right, so it does
// not wrap as long as j <= m_kl + m_ku + i.
90 std::size_t band_storage_row_index(std::size_t
const i, std::size_t
const j)
const
92 return m_kl + m_ku + i - j;
// Reads matrix entry (i, j) from the host view of m_q (override of the base-class accessor).
// NOTE(review): the beginning of the band test and the out-of-band branch are missing from
// this listing; presumably the test starts with `i >= std::` and entries outside the band
// yield zero -- confirm against the real header.
96 double get_element(std::size_t
const i, std::size_t
const j)
const override
101
102
103
104
105
// Lower bound of the band test: max(0, j - m_ku). The std::ptrdiff_t casts make the
// subtraction signed, so j < m_ku gives a negative value that max() clamps to 0 instead of
// an unsigned wrap-around.
107 max(
static_cast<std::ptrdiff_t>(0),
108 static_cast<std::ptrdiff_t>(j) -
static_cast<std::ptrdiff_t>(m_ku))
// Upper bound (exclusive): min(size(), j + m_kl + 1).
109 && i < std::min(size(), j + m_kl + 1)) {
// In-band entry: stored at (band_storage_row_index(i, j), j) of the host view.
110 return m_q.view_host()(band_storage_row_index(i, j), j);
// Writes aij into matrix entry (i, j) on the host view of m_q (override of the base-class
// mutator), using the same band test as get_element().
// NOTE(review): the beginning of the band test and the `else` keyword are missing from this
// listing; presumably the test starts with `i >= std::` and the final assert belongs to the
// out-of-band branch -- confirm against the real header.
116 void set_element(std::size_t
const i, std::size_t
const j,
double const aij)
override
121
122
123
124
125
// Lower bound of the band test: max(0, j - m_ku), computed in signed arithmetic so that
// j < m_ku clamps to 0 rather than wrapping.
127 max(
static_cast<std::ptrdiff_t>(0),
128 static_cast<std::ptrdiff_t>(j) -
static_cast<std::ptrdiff_t>(m_ku))
// Upper bound (exclusive): min(size(), j + m_kl + 1).
129 && i < std::min(size(), j + m_kl + 1)) {
// In-band entry: store the value at (band_storage_row_index(i, j), j) of the host view.
130 m_q.view_host()(band_storage_row_index(i, j), j) = aij;
// Debug-only check that the value is numerically zero (|aij| < 1e-20); presumably reached
// only for entries outside the band, which cannot be stored.
132 assert(std::fabs(aij) < 1e-20);
137
138
139
140
// Prepares the solver: runs LAPACKE_dgbtrf on the host copy of m_q, shifts every pivot index
// down by one, then marks m_ipiv as modified on the host and syncs it to the device.
// Throws std::runtime_error carrying the LAPACKE error code.
// NOTE(review): several LAPACKE_dgbtrf arguments, the condition guarding the throw, and
// possibly a host-to-device sync of m_q are on lines missing from this listing -- confirm.
141 void setup_solver()
override
143 int const info = LAPACKE_dgbtrf(
// Band matrix data (host side), factorized by dgbtrf.
149 m_q.view_host().data(),
// Leading dimension is read from a stride of the host view (the stride index is not visible).
150 m_q.view_host().stride(
// Output array for the pivot indices.
152 m_ipiv.view_host().data());
// NOTE(review): presumably guarded by a check on `info` that is not visible here.
154 throw std::runtime_error(
155 "LAPACKE_dgbtrf failed with error code " + std::to_string(info));
// Subtract 1 from every pivot; presumably converts LAPACK's 1-based pivot indices to the
// 0-based indices expected by the Kokkos kernel used in solve() -- confirm.
159 for (std::size_t i = 0; i < size(); ++i) {
160 m_ipiv.view_host()(i) -= 1;
// Publish the host-side pivot changes to the device view.
166 m_ipiv.modify_host();
167 m_ipiv.sync_device();
171
172
173
174
175
176
177
// Solves the band system for every column of b using the device views of m_q and m_ipiv.
// b must have size() rows (checked by assert); one parallel iteration is launched per column,
// i.e. over [0, b.extent(1)).
// NOTE(review): the branch choosing between the two kernels below is missing from this listing;
// presumably `transpose` selects the Trans::Transpose variant and the other branch runs
// Trans::NoTranspose. The parallel_for label/policy arguments are also not visible -- confirm.
178 void solve(MultiRHS
const b,
bool const transpose)
const override
180 assert(b.extent(0) == size());
// Local copies of the bandwidths and views; presumably so the KOKKOS_LAMBDA captures plain
// values instead of `this` -- confirm.
182 std::size_t
const kl_proxy = m_kl;
183 std::size_t
const ku_proxy = m_ku;
184 auto q_device = m_q.view_device();
185 auto ipiv_device = m_ipiv.view_device();
186 Kokkos::RangePolicy<ExecSpace>
const policy(0, b.extent(1));
// Transposed solve: each iteration takes column i of b and calls the serial Gbtrs kernel.
188 Kokkos::parallel_for(
191 KOKKOS_LAMBDA(
int const i) {
192 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
193 KokkosBatched::SerialGbtrs<
194 KokkosBatched::Trans::Transpose,
195 KokkosBatched::Algo::Level3::Unblocked>::
196 invoke(q_device, sub_b, ipiv_device, kl_proxy, ku_proxy);
// Non-transposed solve: same per-column pattern with Trans::NoTranspose.
199 Kokkos::parallel_for(
202 KOKKOS_LAMBDA(
int const i) {
203 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
204 KokkosBatched::SerialGbtrs<
205 KokkosBatched::Trans::NoTranspose,
206 KokkosBatched::Algo::Level3::Unblocked>::
207 invoke(q_device, sub_b, ipiv_device, kl_proxy, ku_proxy);
The top-level namespace of DDC.