12#include <Kokkos_Core.hpp>
13#include <Kokkos_DualView.hpp>
15#if __has_include(<mkl_lapacke.h>)
16#include <mkl_lapacke.h>
21#include <KokkosBatched_Getrs.hpp>
22#include <KokkosBatched_Util.hpp>
24#include "splines_linear_problem.hpp"
26namespace ddc::detail {
29
30
31
32
33
34
35template <
class ExecSpace>
36class SplinesLinearProblemDense :
public SplinesLinearProblem<ExecSpace>
39 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
40 using SplinesLinearProblem<ExecSpace>::size;
43 Kokkos::DualView<
double**, Kokkos::LayoutRight,
typename ExecSpace::memory_space> m_a;
44 Kokkos::DualView<
int*,
typename ExecSpace::memory_space> m_ipiv;
48
49
50
51
52 explicit SplinesLinearProblemDense(std::size_t
const mat_size)
53 : SplinesLinearProblem<ExecSpace>(mat_size)
54 , m_a(
"a", mat_size, mat_size)
55 , m_ipiv(
"ipiv", mat_size)
57 Kokkos::deep_copy(m_a.h_view, 0.);
60 double get_element(std::size_t
const i, std::size_t
const j)
const override
64 return m_a.h_view(i, j);
67 void set_element(std::size_t
const i, std::size_t
const j,
double const aij)
override
71 m_a.h_view(i, j) = aij;
75
76
77
78
79 void setup_solver()
override
81 int const info = LAPACKE_dgetrf(
87 m_ipiv.h_view.data());
89 throw std::runtime_error(
90 "LAPACKE_dgetrf failed with error code " + std::to_string(info));
94 for (
int i = 0; i < size(); ++i) {
95 m_ipiv.h_view(i) -= 1;
101 m_ipiv.modify_host();
102 m_ipiv.sync_device();
106
107
108
109
110
111
112
113 void solve(MultiRHS
const b,
bool const transpose)
const override
115 assert(b.extent(0) == size());
122 auto a_device = m_a.d_view;
123 auto ipiv_device = m_ipiv.d_view;
125 Kokkos::RangePolicy<ExecSpace>
const policy(0, b.extent(1));
128 Kokkos::parallel_for(
131 KOKKOS_LAMBDA(
const int i) {
132 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
133 KokkosBatched::SerialGetrs<
134 KokkosBatched::Trans::Transpose,
135 KokkosBatched::Algo::Getrs::Unblocked>::
136 invoke(a_device, ipiv_device, sub_b);
139 Kokkos::parallel_for(
142 KOKKOS_LAMBDA(
const int i) {
143 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
144 KokkosBatched::SerialGetrs<
145 KokkosBatched::Trans::NoTranspose,
146 KokkosBatched::Algo::Getrs::Unblocked>::
147 invoke(a_device, ipiv_device, sub_b);
The top-level namespace of DDC.