14#include <Kokkos_Core.hpp>
15#include <Kokkos_DualView.hpp>
17#if __has_include(<mkl_lapacke.h>)
18#include <mkl_lapacke.h>
23#include <KokkosBatched_Pttrs.hpp>
24#include <KokkosBatched_Util.hpp>
26#include "splines_linear_problem.hpp"
28namespace ddc::detail {
31
32
33
34
35
36
37
38
39
40
41
42
43template <
class ExecSpace>
44class SplinesLinearProblemPDSTridiag :
public SplinesLinearProblem<ExecSpace>
47 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
48 using SplinesLinearProblem<ExecSpace>::size;
51 Kokkos::DualView<
double**, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
56
57
58
59
60 explicit SplinesLinearProblemPDSTridiag(std::size_t
const mat_size)
61 : SplinesLinearProblem<ExecSpace>(mat_size)
62 , m_q(
"q", 2, mat_size)
64 Kokkos::deep_copy(m_q.h_view, 0.);
67 double get_element(std::size_t i, std::size_t j)
const override
78 return m_q.h_view(j - i, i);
84 void set_element(std::size_t i, std::size_t j,
double const aij)
override
94 m_q.h_view(j - i, i) = aij;
96 assert(std::fabs(aij) < 1e-20);
101
102
103
104
105 void setup_solver()
override
107 int const info = LAPACKE_dpttrf(
110 m_q.h_view.data() + m_q.h_view.stride(0));
112 throw std::runtime_error(
113 "LAPACKE_dpttrf failed with error code " + std::to_string(info));
122
123
124
125
126
127
128
129 void solve(MultiRHS
const b,
bool const)
const override
131 assert(b.extent(0) == size());
132 auto q_device = m_q.d_view;
133 auto d = Kokkos::subview(q_device, 0, Kokkos::ALL);
135 subview(q_device, 1, Kokkos::pair<
int,
int>(0, q_device.extent_int(1) - 1));
136 Kokkos::RangePolicy<ExecSpace>
const policy(0, b.extent(1));
137 Kokkos::parallel_for(
140 KOKKOS_LAMBDA(
const int i) {
141 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
142 KokkosBatched::SerialPttrs<
143 KokkosBatched::Uplo::Lower,
144 KokkosBatched::Algo::Pttrs::Unblocked>::invoke(d, e, sub_b);
The top-level namespace of DDC.