14#include <Kokkos_Core.hpp>
15#include <Kokkos_DualView.hpp>
17#if __has_include(<mkl_lapacke.h>)
18#include <mkl_lapacke.h>
23#include <KokkosBatched_Util.hpp>
25#include "kokkos-kernels-ext/KokkosBatched_Gbtrs.hpp"
27#include "splines_linear_problem.hpp"
29namespace ddc::detail {
32
33
34
35
36
37
38
39
40
41
42
43
44
// Linear problem whose matrix is kept in band storage; specialises
// SplinesLinearProblem<ExecSpace>.
// ExecSpace supplies the memory_space of the views below and is the execution
// space of the Kokkos::RangePolicy used in solve().
// NOTE(review): this listing has gaps (braces, access specifiers and some
// declarations are not visible) - check against the complete header before
// relying on these notes.
45template <
class ExecSpace>
46class SplinesLinearProblemBand :
public SplinesLinearProblem<ExecSpace>
// Bring the base-class MultiRHS type and size() into scope so they can be
// used unqualified despite the dependent base.
49 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
50 using SplinesLinearProblem<ExecSpace>::size;
// Two-dimensional, LayoutRight DualView of doubles in ExecSpace's memory space.
// NOTE(review): the declarator is not visible here; presumably this is m_q,
// which the constructor sizes as (2 * kl + ku + 1) x mat_size - confirm.
55 Kokkos::DualView<
double**, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
// One int per matrix row: receives the pivot indices written by
// LAPACKE_dgbtrf in setup_solver() and is read on the device in solve().
57 Kokkos::DualView<
int*,
typename ExecSpace::memory_space> m_ipiv;
61
62
63
64
65
66
// Constructs a band problem of dimension mat_size and allocates its storage.
//
// mat_size: forwarded to the SplinesLinearProblem<ExecSpace> base; also the
//           column count of m_q and the length of m_ipiv.
// NOTE(review): kl and ku are used below, but their parameter declarations and
// the m_kl / m_ku initialisers are not visible in this listing. Presumably they
// are the sub-/super-diagonal counts expected by dgbtrf - confirm.
67 explicit SplinesLinearProblemBand(
68 std::size_t
const mat_size,
71 : SplinesLinearProblem<ExecSpace>(mat_size)
75
76
77
// Band storage has 2 * kl + ku + 1 rows. This matches the minimum leading
// dimension LAPACK's dgbtrf documents for its band array (kl extra rows on top
// of kl + ku + 1) - presumably the reason for the extra kl rows here.
78 , m_q(
"q", 2 * kl + ku + 1, mat_size)
// One pivot entry per matrix row.
79 , m_ipiv(
"ipiv", mat_size)
// Neither bandwidth may exceed the matrix dimension (checked only when
// assertions are enabled).
81 assert(m_kl <= mat_size);
82 assert(m_ku <= mat_size);
// Zero-fill the host view of the band storage.
84 Kokkos::deep_copy(m_q.h_view, 0.);
// Returns the row of the band storage that holds matrix entry (i, j):
// m_kl + m_ku + i - j. The column index of the entry is j itself (see the
// callers, which index m_q.h_view with (this value, j)).
// i and j are std::size_t, so the subtraction is unsigned; both visible callers
// only invoke this inside their in-band check.
88 std::size_t band_storage_row_index(std::size_t
const i, std::size_t
const j)
const
90 return m_kl + m_ku + i - j;
// Returns matrix entry (i, j), read from the host view of the band storage
// when (i, j) satisfies the band condition below.
// NOTE(review): the beginning of the condition and the branch taken when the
// condition is false are not visible in this listing - presumably entries
// outside the band read as zero; confirm against the full header.
94 double get_element(std::size_t
const i, std::size_t
const j)
const override
99
100
101
102
103
// Lower edge of the band for column j: max(0, j - m_ku). The difference is
// taken in std::ptrdiff_t so that j < m_ku yields a negative value that is
// clamped to 0 instead of wrapping around as an unsigned subtraction would.
105 max(
static_cast<std::ptrdiff_t>(0),
106 static_cast<std::ptrdiff_t>(j) -
static_cast<std::ptrdiff_t>(m_ku))
// Upper edge (exclusive): j + m_kl + 1, capped at the matrix dimension.
107 && i < std::min(size(), j + m_kl + 1)) {
// In-band: fetch from band row band_storage_row_index(i, j), column j.
108 return m_q.h_view(band_storage_row_index(i, j), j);
// Writes aij as matrix entry (i, j) into the host view of the band storage
// when (i, j) satisfies the band condition below.
// NOTE(review): the beginning of the condition and the keyword introducing the
// fallback path are not visible in this listing; the trailing assert is
// presumably the out-of-band branch - confirm against the full header.
114 void set_element(std::size_t
const i, std::size_t
const j,
double const aij)
override
119
120
121
122
123
// Lower edge of the band for column j: max(0, j - m_ku), computed in
// std::ptrdiff_t so that j < m_ku clamps to 0 rather than wrapping.
125 max(
static_cast<std::ptrdiff_t>(0),
126 static_cast<std::ptrdiff_t>(j) -
static_cast<std::ptrdiff_t>(m_ku))
// Upper edge (exclusive): j + m_kl + 1, capped at the matrix dimension.
127 && i < std::min(size(), j + m_kl + 1)) {
// In-band: store into band row band_storage_row_index(i, j), column j.
// Only the host view is written here.
128 m_q.h_view(band_storage_row_index(i, j), j) = aij;
// Values that cannot be stored must be negligible (|aij| < 1e-20); this is
// enforced only when assertions are enabled.
130 assert(std::fabs(aij) < 1e-20);
135
136
137
138
135
136
137
138
// Prepares the problem for solve(): runs LAPACKE_dgbtrf, converts the pivot
// indices it produced, and publishes them to the device.
// Throws std::runtime_error carrying the dgbtrf return code.
// NOTE(review): all dgbtrf arguments except the pivot pointer, and the
// condition guarding the throw, are not visible in this listing. No sync of
// m_q to the device is visible here either - confirm against the full header.
139 void setup_solver()
override
141 int const info = LAPACKE_dgbtrf(
// Last argument: host pointer that receives the pivot indices.
150 m_ipiv.h_view.data());
152 throw std::runtime_error(
153 "LAPACKE_dgbtrf failed with error code " + std::to_string(info));
// Shift every pivot down by one: LAPACK documents ipiv as 1-based, and the
// device kernel in solve() presumably expects 0-based indices.
// NOTE(review): `int i` is compared with size(); if size() returns an unsigned
// type this is a signed/unsigned comparison - verify.
157 for (
int i = 0; i < size(); ++i) {
158 m_ipiv.h_view(i) -= 1;
// Flag the host copy as modified, then copy it to the device view.
164 m_ipiv.modify_host();
165 m_ipiv.sync_device();
169
170
171
172
173
174
175
169
170
171
172
173
174
175
// Applies the band solve (KokkosBatched::SerialGbtrs) to every column of b,
// one column per parallel iteration.
//
// b:         right-hand sides; extent(0) must equal size(), extent(1) is the
//            number of columns processed. Each column is handed to
//            SerialGbtrs and is presumably overwritten with the solution -
//            confirm in KokkosBatched_Gbtrs.hpp.
// transpose: NOTE(review): the statement that branches on this flag is not
//            visible in this listing; presumably it selects the
//            Trans::Transpose kernel over the Trans::NoTranspose one.
176 void solve(MultiRHS
const b,
bool const transpose)
const override
178 assert(b.extent(0) == size());
// Local copies of the bandwidths and device views, so that the lambdas below
// use these locals rather than class members.
180 std::size_t
const kl_proxy = m_kl;
181 std::size_t
const ku_proxy = m_ku;
182 auto q_device = m_q.d_view;
183 auto ipiv_device = m_ipiv.d_view;
// One iteration per column of b.
184 Kokkos::RangePolicy<ExecSpace>
const policy(0, b.extent(1));
// Transposed solve.
// NOTE(review): the lambda index is `int` while the range bound comes from
// b.extent(1) - verify that the column count always fits in an int.
186 Kokkos::parallel_for(
189 KOKKOS_LAMBDA(
const int i) {
// Column i of b, all rows.
190 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
191 KokkosBatched::SerialGbtrs<
192 KokkosBatched::Trans::Transpose,
193 KokkosBatched::Algo::Level3::Unblocked>::
194 invoke(q_device, sub_b, ipiv_device, kl_proxy, ku_proxy);
// Non-transposed solve; identical to the kernel above apart from the Trans tag.
197 Kokkos::parallel_for(
200 KOKKOS_LAMBDA(
const int i) {
// Column i of b, all rows.
201 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
202 KokkosBatched::SerialGbtrs<
203 KokkosBatched::Trans::NoTranspose,
204 KokkosBatched::Algo::Level3::Unblocked>::
205 invoke(q_device, sub_b, ipiv_device, kl_proxy, ku_proxy);
The top-level namespace of DDC.