14#include <Kokkos_Core.hpp>
16#if __has_include
(<mkl_lapacke.h>)
17# include <mkl_lapacke.h>
22#include <KokkosBatched_Gbtrs.hpp>
23#include <KokkosBatched_Util.hpp>
28namespace ddc::detail {
// Constructs a band linear problem of dimension mat_size.
//
// Storage:
//  - m_q holds 2 * kl + ku + 1 rows and mat_size columns, i.e. the LAPACK
//    general-band layout. The extra kl rows beyond the kl + ku + 1 needed for
//    the band itself are the workspace LAPACKE_dgbtrf (see setup_solver)
//    requires for fill-in produced by partial pivoting.
//  - m_ipiv holds mat_size pivot indices filled by the factorization.
//
// kl and ku are presumably the lower/upper half-bandwidths passed to this
// constructor (their declaration is not visible in this chunk — confirm).
// Both stored bandwidths are asserted to be <= mat_size.
// The host view of m_q is zero-initialized so that band entries never set via
// set_element read back as 0.
30template <
class ExecSpace>
31SplinesLinearProblemBand<ExecSpace>::SplinesLinearProblemBand(
32 std::size_t
const mat_size,
35 : SplinesLinearProblem<ExecSpace>(mat_size)
39
40
41
42 , m_q(
"q", 2 * kl + ku + 1, mat_size)
43 , m_ipiv(
"ipiv", mat_size)
45 assert(m_kl <= mat_size);
46 assert(m_ku <= mat_size);
48 Kokkos::deep_copy(m_q.view_host(), 0.);
// Defaulted destructor: the Kokkos views m_q and m_ipiv release their own
// storage.
51template <
class ExecSpace>
52SplinesLinearProblemBand<ExecSpace>::~SplinesLinearProblemBand() =
default;
// Maps matrix coordinates (i, j) to the row of m_q that stores entry (i, j);
// the column of m_q is j itself.
// This is the 0-based form of the LAPACK general-band layout with kl extra
// pivoting rows on top: row = m_kl + m_ku + i - j.
// Only meaningful for (i, j) inside the band. There j - i <= m_ku, so the
// unsigned expression stays non-negative.
54template <
class ExecSpace>
55std::size_t SplinesLinearProblemBand<ExecSpace>::band_storage_row_index(
57 std::size_t
const j)
const
59 return m_kl + m_ku + i - j;
// Returns coefficient (i, j) of the matrix, read from the host view of the
// band storage m_q.
//
// The entry is read from storage when (i, j) lies inside the band, i.e.
// max(0, j - m_ku) <= i < min(size(), j + m_kl + 1). The lower bound is
// evaluated in std::ptrdiff_t so that j - m_ku cannot wrap around when
// j < m_ku.
// NOTE(review): the comparison of i against the lower bound and the
// out-of-band branch are on lines not visible in this chunk; presumably
// entries outside the band return 0 — confirm.
62template <
class ExecSpace>
63double SplinesLinearProblemBand<ExecSpace>::get_element(std::size_t
const i, std::size_t
const j)
69
70
71
72
73
75 max(
static_cast<std::ptrdiff_t>(0),
76 static_cast<std::ptrdiff_t>(j) -
static_cast<std::ptrdiff_t>(m_ku))
77 && i < std::min(size(), j + m_kl + 1)) {
78 return m_q.view_host()(band_storage_row_index(i, j), j);
// Sets coefficient (i, j) of the matrix to aij in the host view of the band
// storage m_q.
//
// If (i, j) lies inside the band (same test as get_element), aij is stored.
// Otherwise nothing is stored and the value is asserted to be numerically
// zero (|aij| < 1e-15), since entries outside the band cannot be represented.
// NOTE(review): only the host view is written here; the transfer of m_q to
// the device is not visible in this chunk — confirm it happens before solve().
84template <
class ExecSpace>
85void SplinesLinearProblemBand<ExecSpace>::set_element(
93
94
95
96
97
99 max(
static_cast<std::ptrdiff_t>(0),
100 static_cast<std::ptrdiff_t>(j) -
static_cast<std::ptrdiff_t>(m_ku))
101 && i < std::min(size(), j + m_kl + 1)) {
102 m_q.view_host()(band_storage_row_index(i, j), j) = aij;
104 assert(std::fabs(aij) < 1e-15);
// Factorizes the band matrix in place on the host by calling LAPACKE_dgbtrf
// (LU factorization with partial pivoting).
//  - The factors overwrite m_q.
//  - The pivot indices are written to m_ipiv.
//  - A std::runtime_error carrying the LAPACK info code is thrown on failure.
//
// LAPACK returns 1-based pivot indices. They are shifted to 0-based here,
// presumably because KokkosBatched::SerialGbtrs (used in solve) expects
// 0-based pivots — confirm. The pivots are then marked as modified on the
// host and synchronized to the device.
// NOTE(review): any device synchronization of m_q would be on lines not
// visible in this chunk — confirm.
108template <
class ExecSpace>
109void SplinesLinearProblemBand<ExecSpace>::setup_solver()
111 int const info = LAPACKE_dgbtrf(
117 m_q.view_host().data(),
118 m_q.view_host().stride(
120 m_ipiv.view_host().data());
122 throw std::runtime_error(
"LAPACKE_dgbtrf failed with error code " + std::to_string(info));
// Convert LAPACK's 1-based pivot indices to 0-based indices.
126 for (std::size_t i = 0; i < size(); ++i) {
127 m_ipiv.view_host()(i) -= 1;
// Publish the host-side pivot update to the device copy.
133 m_ipiv.modify_host();
134 m_ipiv.sync_device();
// Solves the factorized band system in place for every column of b.
//
// Work is distributed with one parallel iteration per right-hand side:
// the RangePolicy spans [0, b.extent(1)) on ExecSpace. Each column sub_b is
// solved by KokkosBatched::SerialGbtrs using the device views of the LU
// factors (m_q) and the pivots (m_ipiv).
//
// Two kernels are visible: one using Trans::Transpose and one using
// Trans::NoTranspose. The branch on `transpose` that selects between them is
// on lines not visible in this chunk — confirm.
// Requires b.extent(0) == size() (asserted). Presumably setup_solver() must
// have been called first.
137template <
class ExecSpace>
138void SplinesLinearProblemBand<ExecSpace>::solve(MultiRHS
const b,
bool const transpose)
const
140 assert(b.extent(0) == size());
// Copy bandwidths and device views into locals, presumably so the
// KOKKOS_LAMBDA captures plain values instead of dereferencing `this`
// on the device.
142 std::size_t
const kl_proxy = m_kl;
143 std::size_t
const ku_proxy = m_ku;
144 auto q_device = m_q.view_device();
145 auto ipiv_device = m_ipiv.view_device();
146 Kokkos::RangePolicy<ExecSpace>
const policy(0, b.extent(1));
// Kernel solving each column with the transposed factorization.
148 Kokkos::parallel_for(
151 KOKKOS_LAMBDA(
int const i) {
152 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
153 KokkosBatched::SerialGbtrs<
154 KokkosBatched::Trans::Transpose,
155 KokkosBatched::Algo::Gbtrs::Unblocked>::
156 invoke(q_device, ipiv_device, sub_b, kl_proxy, ku_proxy);
// Kernel solving each column with the non-transposed factorization.
159 Kokkos::parallel_for(
162 KOKKOS_LAMBDA(
int const i) {
163 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
164 KokkosBatched::SerialGbtrs<
165 KokkosBatched::Trans::NoTranspose,
166 KokkosBatched::Algo::Gbtrs::Unblocked>::
167 invoke(q_device, ipiv_device, sub_b, kl_proxy, ku_proxy);
// Explicit instantiations of SplinesLinearProblemBand for each Kokkos
// execution space enabled in the build (Serial, OpenMP, CUDA, HIP, SYCL).
// The member definitions above are emitted in this translation unit only for
// those spaces.
172#if defined(KOKKOS_ENABLE_SERIAL
)
173template class SplinesLinearProblemBand<Kokkos::Serial>;
175#if defined(KOKKOS_ENABLE_OPENMP)
176template class SplinesLinearProblemBand<Kokkos::OpenMP>;
178#if defined(KOKKOS_ENABLE_CUDA)
179template class SplinesLinearProblemBand<Kokkos::Cuda>;
181#if defined(KOKKOS_ENABLE_HIP)
182template class SplinesLinearProblemBand<Kokkos::HIP>;
184#if defined(KOKKOS_ENABLE_SYCL)
185template class SplinesLinearProblemBand<Kokkos::SYCL>;
The top-level namespace of DDC.