DDC 0.10.0
Loading...
Searching...
No Matches
splines_linear_problem_pds_band.cpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#include <cassert>
6#if !defined(NDEBUG)
7# include <cmath>
8#endif
9#include <cstddef>
10#include <stdexcept>
11#include <string>
12#include <utility>
13
14#include <Kokkos_Core.hpp>
15#include <Kokkos_DualView.hpp>
16
17#if __has_include(<mkl_lapacke.h>)
18# include <mkl_lapacke.h>
19#else
20# include <lapacke.h>
21#endif
22
23#include <KokkosBatched_Pbtrs.hpp>
24#include <KokkosBatched_Util.hpp>
25
28
29namespace ddc::detail {
30
31template <class ExecSpace>
32SplinesLinearProblemPDSBand<ExecSpace>::SplinesLinearProblemPDSBand(
33 std::size_t const mat_size,
34 std::size_t const kd)
35 : SplinesLinearProblem<ExecSpace>(mat_size)
36 , m_q("q", kd + 1, mat_size)
37{
38 assert(m_q.extent(0) <= mat_size);
39
40 Kokkos::deep_copy(m_q.view_host(), 0.);
41}
42
43template <class ExecSpace>
44SplinesLinearProblemPDSBand<ExecSpace>::~SplinesLinearProblemPDSBand() = default;
45
46template <class ExecSpace>
47double SplinesLinearProblemPDSBand<ExecSpace>::get_element(std::size_t i, std::size_t j) const
48{
49 assert(i < size());
50 assert(j < size());
51
52 // Indices are swapped for an element on subdiagonal
53 if (i > j) {
54 std::swap(i, j);
55 }
56
57 if (j - i < m_q.extent(0)) {
58 return m_q.view_host()(j - i, i);
59 }
60
61 return 0.0;
62}
63
64template <class ExecSpace>
65void SplinesLinearProblemPDSBand<ExecSpace>::set_element(
66 std::size_t i,
67 std::size_t j,
68 double const aij)
69{
70 assert(i < size());
71 assert(j < size());
72
73 // Indices are swapped for an element on subdiagonal
74 if (i > j) {
75 std::swap(i, j);
76 }
77 if (j - i < m_q.extent(0)) {
78 m_q.view_host()(j - i, i) = aij;
79 } else {
80 assert(std::fabs(aij) < 1e-15);
81 }
82}
83
84template <class ExecSpace>
85void SplinesLinearProblemPDSBand<ExecSpace>::setup_solver()
86{
87 int const info = LAPACKE_dpbtrf(
88 LAPACK_ROW_MAJOR,
89 'L',
90 size(),
91 m_q.extent(0) - 1,
92 m_q.view_host().data(),
93 m_q.view_host().stride(
94 0) // m_q.view_host().stride(0) if LAPACK_ROW_MAJOR, m_q.view_host().stride(1) if LAPACK_COL_MAJOR
95 );
96 if (info != 0) {
97 throw std::runtime_error("LAPACKE_dpbtrf failed with error code " + std::to_string(info));
98 }
99
100 // Push on device
101 m_q.modify_host();
102 m_q.sync_device();
103}
104
105template <class ExecSpace>
106void SplinesLinearProblemPDSBand<ExecSpace>::solve(MultiRHS const b, bool const) const
107{
108 assert(b.extent(0) == size());
109
110 auto q_device = m_q.view_device();
111 Kokkos::RangePolicy<ExecSpace> const policy(0, b.extent(1));
112 Kokkos::parallel_for(
113 "pbtrs",
114 policy,
115 KOKKOS_LAMBDA(int const i) {
116 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
117 KokkosBatched::SerialPbtrs<
118 KokkosBatched::Uplo::Lower,
119 KokkosBatched::Algo::Pbtrs::Unblocked>::invoke(q_device, sub_b);
120 });
121}
122
// Explicit instantiations: one per Kokkos execution space enabled in this build.
#ifdef KOKKOS_ENABLE_SERIAL
template class SplinesLinearProblemPDSBand<Kokkos::Serial>;
#endif
#ifdef KOKKOS_ENABLE_OPENMP
template class SplinesLinearProblemPDSBand<Kokkos::OpenMP>;
#endif
#ifdef KOKKOS_ENABLE_CUDA
template class SplinesLinearProblemPDSBand<Kokkos::Cuda>;
#endif
#ifdef KOKKOS_ENABLE_HIP
template class SplinesLinearProblemPDSBand<Kokkos::HIP>;
#endif
#ifdef KOKKOS_ENABLE_SYCL
template class SplinesLinearProblemPDSBand<Kokkos::SYCL>;
#endif
138
139} // namespace ddc::detail
The top-level namespace of DDC.