14#include <Kokkos_Core.hpp>
15#include <Kokkos_DualView.hpp>
17#if __has_include
(<mkl_lapacke.h>)
18# include <mkl_lapacke.h>
23#include <KokkosBatched_Pttrs.hpp>
24#include <KokkosBatched_Util.hpp>
29namespace ddc::detail {
// Constructor: builds the linear problem for a matrix of size `mat_size`.
//
// - Forwards `mat_size` to the SplinesLinearProblem<ExecSpace> base.
// - Creates the storage `m_q` with label "q" and extents (2, mat_size).
// - Fills the host-side view of `m_q` with zeros.
//
// NOTE(review): the leading integers on each line look like line numbers
// from a rendered source listing, and some original lines (e.g. the braces
// and listing line 36) are not present in this excerpt.
31template <
class ExecSpace>
32SplinesLinearProblemPDSTridiag<ExecSpace>::SplinesLinearProblemPDSTridiag(
33 std::size_t
const mat_size)
34 : SplinesLinearProblem<ExecSpace>(mat_size)
// Two rows of length mat_size; get_element/set_element index them as
// (j - i, i), so presumably row 0 is the diagonal and row 1 the
// off-diagonal -- confirm against the class declaration.
35 , m_q(
"q", 2, mat_size)
// Start from an all-zero matrix on the host side.
37 Kokkos::deep_copy(m_q.view_host(), 0.);
// Destructor: explicitly defaulted out of line, so the compiler-generated
// destructor is used and no custom cleanup is performed here.
40template <
class ExecSpace>
41SplinesLinearProblemPDSTridiag<ExecSpace>::~SplinesLinearProblemPDSTridiag() =
default;
// Returns the matrix coefficient at row `i`, column `j` as a double.
//
// The visible return statement reads the host-side view of `m_q` at
// (j - i, i): the first index is the offset between column and row, the
// second index is the row.
//
// NOTE(review): listing lines 45-54 and 56-60 are not present in this
// excerpt. Presumably they bounds-check i and j, swap them when i > j
// (otherwise the unsigned `j - i` would wrap), and return 0 for entries
// outside the stored diagonals -- confirm against the full file.
43template <
class ExecSpace>
44double SplinesLinearProblemPDSTridiag<ExecSpace>::get_element(std::size_t i, std::size_t j)
const
// Read from the host copy of the storage; `j - i` selects the row of m_q.
55 return m_q.view_host()(j - i, i);
// Sets a matrix coefficient to the value `aij`.
//
// The two visible statements are:
//  - a write of `aij` into the host-side view of `m_q` at (j - i, i);
//  - an assertion that |aij| < 1e-15.
//
// NOTE(review): the parameter list and listing lines 63-74, 76 and 78-80
// are not present in this excerpt. Presumably the write applies to entries
// within the stored diagonals and the assertion guards attempts to set a
// non-zero value anywhere else -- confirm against the full file.
61template <
class ExecSpace>
62void SplinesLinearProblemPDSTridiag<ExecSpace>::set_element(
// Store the value in the host copy; same (j - i, i) layout as get_element.
75 m_q.view_host()(j - i, i) = aij;
// Active only when assertions are enabled: the value must be numerically zero.
77 assert(std::fabs(aij) < 1e-15);
// Prepares the solver by factorizing the matrix in place on the host with
// LAPACKE_dpttrf.
//
// The two visible pointer arguments are the start of the host view of `m_q`
// and that pointer advanced by `stride(0)`, i.e. the element (1, 0) of the
// view. Presumably these are the diagonal and off-diagonal arrays that
// dpttrf expects -- confirm against the LAPACKE documentation.
//
// Throws std::runtime_error carrying the LAPACKE return code.
//
// NOTE(review): the first LAPACKE argument (listing line 85), the condition
// guarding the throw (listing line 88, presumably `info != 0`), and listing
// lines 90-96 are not present in this excerpt. solve() reads
// m_q.view_device(), so a host-to-device synchronisation presumably happens
// in the missing lines -- confirm against the full file.
81template <
class ExecSpace>
82void SplinesLinearProblemPDSTridiag<ExecSpace>::setup_solver()
84 int const info = LAPACKE_dpttrf(
// Row 0 of m_q (host side).
86 m_q.view_host().data(),
// Row 1 of m_q (host side), reached through the stride of dimension 0.
87 m_q.view_host().data() + m_q.view_host().stride(0));
// Report the failure to the caller together with the LAPACKE error code.
89 throw std::runtime_error(
"LAPACKE_dpttrf failed with error code " + std::to_string(info));
// Solves the linear system for every column of the multi-right-hand-side
// view `b` using the data stored in the device view of `m_q`.
//
// Parameters:
//  - b: right-hand sides; its first extent must equal size() (asserted).
//       One solve is launched per index along its second extent.
//       Presumably `b` is overwritten with the solution, following the
//       KokkosBatched Pttrs convention -- confirm against its documentation.
//  - (unnamed bool): not used in the visible lines.
//
// NOTE(review): the leading arguments of Kokkos::parallel_for (listing lines
// 106-107) and the closing tokens of the lambda/call are not present in
// this excerpt.
97template <
class ExecSpace>
98void SplinesLinearProblemPDSTridiag<ExecSpace>::solve(MultiRHS
const b,
bool const)
const
// Each right-hand side must have as many rows as the matrix.
100 assert(b.extent(0) == size());
// Local copy of the device view so the lambda below does not use `this`.
101 auto q_device = m_q.view_device();
// d: the whole of row 0 of the storage.
102 auto d = Kokkos::subview(q_device, 0, Kokkos::ALL);
// e: row 1 restricted to [0, extent(1) - 1), i.e. one entry fewer than d.
103 auto e = Kokkos::subview(q_device, 1, Kokkos::pair<
int,
int>(0, q_device.extent_int(1) - 1));
// One iteration per column of b, run on ExecSpace.
104 Kokkos::RangePolicy<ExecSpace>
const policy(0, b.extent(1));
105 Kokkos::parallel_for(
108 KOKKOS_LAMBDA(
int const i) {
// Column i of b is handled independently by a serial Pttrs call.
109 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
110 KokkosBatched::SerialPttrs<
111 KokkosBatched::Uplo::Lower,
112 KokkosBatched::Algo::Pttrs::Unblocked>::invoke(d, e, sub_b);
// Explicit instantiations of SplinesLinearProblemPDSTridiag, one per Kokkos
// execution space. Each is guarded by the matching KOKKOS_ENABLE_* macro so
// that only the enabled backends are instantiated.
//
// NOTE(review): the matching #endif lines (listing lines 118, 121, 124, 127
// and 130) are not present in this excerpt.

// Kokkos::Serial backend.
116#if defined(KOKKOS_ENABLE_SERIAL
)
117template class SplinesLinearProblemPDSTridiag<Kokkos::Serial>;
// Kokkos::OpenMP backend.
119#if defined(KOKKOS_ENABLE_OPENMP)
120template class SplinesLinearProblemPDSTridiag<Kokkos::OpenMP>;
// Kokkos::Cuda backend.
122#if defined(KOKKOS_ENABLE_CUDA)
123template class SplinesLinearProblemPDSTridiag<Kokkos::Cuda>;
// Kokkos::HIP backend.
125#if defined(KOKKOS_ENABLE_HIP)
126template class SplinesLinearProblemPDSTridiag<Kokkos::HIP>;
// Kokkos::SYCL backend.
128#if defined(KOKKOS_ENABLE_SYCL)
129template class SplinesLinearProblemPDSTridiag<Kokkos::SYCL>;
The top-level namespace of DDC.