DDC 0.11.0
Loading...
Searching...
No Matches
splines_linear_problem_pds_tridiag.cpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#include <cassert>
6#if !defined(NDEBUG)
7# include <cmath>
8#endif
9#include <cstddef>
10#include <stdexcept>
11#include <string>
12#include <utility>
13
14#include <Kokkos_Core.hpp>
15
16#if __has_include(<mkl_lapacke.h>)
17# include <mkl_lapacke.h>
18#else
19# include <lapacke.h>
20#endif
21
22#include <KokkosBatched_Pttrs.hpp>
23#include <KokkosBatched_Util.hpp>
24
27
28namespace ddc::detail {
29
30template <class ExecSpace>
31SplinesLinearProblemPDSTridiag<ExecSpace>::SplinesLinearProblemPDSTridiag(
32 std::size_t const mat_size)
33 : SplinesLinearProblem<ExecSpace>(mat_size)
34 , m_q("q", 2, mat_size)
35{
36 Kokkos::deep_copy(m_q.view_host(), 0.);
37}
38
39template <class ExecSpace>
40SplinesLinearProblemPDSTridiag<ExecSpace>::~SplinesLinearProblemPDSTridiag() = default;
41
42template <class ExecSpace>
43double SplinesLinearProblemPDSTridiag<ExecSpace>::get_element(std::size_t i, std::size_t j) const
44{
45 assert(i < size());
46 assert(j < size());
47
48 // Indices are swapped for an element on subdiagonal
49 if (i > j) {
50 std::swap(i, j);
51 }
52
53 if (j - i < 2) {
54 return m_q.view_host()(j - i, i);
55 }
56
57 return 0.0;
58}
59
60template <class ExecSpace>
61void SplinesLinearProblemPDSTridiag<ExecSpace>::set_element(
62 std::size_t i,
63 std::size_t j,
64 double const aij)
65{
66 assert(i < size());
67 assert(j < size());
68
69 // Indices are swapped for an element on subdiagonal
70 if (i > j) {
71 std::swap(i, j);
72 }
73 if (j - i < 2) {
74 m_q.view_host()(j - i, i) = aij;
75 } else {
76 assert(std::fabs(aij) < 1e-15);
77 }
78}
79
80template <class ExecSpace>
81void SplinesLinearProblemPDSTridiag<ExecSpace>::setup_solver()
82{
83 int const info = LAPACKE_dpttrf(
84 size(),
85 m_q.view_host().data(),
86 m_q.view_host().data() + m_q.view_host().stride(0));
87 if (info != 0) {
88 throw std::runtime_error("LAPACKE_dpttrf failed with error code " + std::to_string(info));
89 }
90
91 // Push on device
92 m_q.modify_host();
93 m_q.sync_device();
94}
95
96template <class ExecSpace>
97void SplinesLinearProblemPDSTridiag<ExecSpace>::solve(MultiRHS const b, bool const) const
98{
99 assert(b.extent(0) == size());
100 auto q_device = m_q.view_device();
101 auto d = Kokkos::subview(q_device, 0, Kokkos::ALL);
102 auto e = Kokkos::subview(q_device, 1, Kokkos::pair<int, int>(0, q_device.extent_int(1) - 1));
103 Kokkos::RangePolicy<ExecSpace> const policy(0, b.extent(1));
104 Kokkos::parallel_for(
105 "pttrs",
106 policy,
107 KOKKOS_LAMBDA(int const i) {
108 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
109 KokkosBatched::SerialPttrs<
110 KokkosBatched::Uplo::Lower,
111 KokkosBatched::Algo::Pttrs::Unblocked>::invoke(d, e, sub_b);
112 });
113}
114
// Explicitly instantiate the solver for every Kokkos execution space enabled in
// this build, so the template definitions above are compiled once in this file.
#ifdef KOKKOS_ENABLE_SERIAL
template class SplinesLinearProblemPDSTridiag<Kokkos::Serial>;
#endif
#ifdef KOKKOS_ENABLE_OPENMP
template class SplinesLinearProblemPDSTridiag<Kokkos::OpenMP>;
#endif
#ifdef KOKKOS_ENABLE_CUDA
template class SplinesLinearProblemPDSTridiag<Kokkos::Cuda>;
#endif
#ifdef KOKKOS_ENABLE_HIP
template class SplinesLinearProblemPDSTridiag<Kokkos::HIP>;
#endif
#ifdef KOKKOS_ENABLE_SYCL
template class SplinesLinearProblemPDSTridiag<Kokkos::SYCL>;
#endif
130
131} // namespace ddc::detail
The top-level namespace of DDC.