DDC 0.11.0
Loading...
Searching...
No Matches
splines_linear_problem_pds_band.cpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#include <cassert>
6#if !defined(NDEBUG)
7# include <cmath>
8#endif
9#include <cstddef>
10#include <stdexcept>
11#include <string>
12#include <utility>
13
14#include <Kokkos_Core.hpp>
15
16#if __has_include(<mkl_lapacke.h>)
17# include <mkl_lapacke.h>
18#else
19# include <lapacke.h>
20#endif
21
22#include <KokkosBatched_Pbtrs.hpp>
23#include <KokkosBatched_Util.hpp>
24
27
28namespace ddc::detail {
29
30template <class ExecSpace>
31SplinesLinearProblemPDSBand<ExecSpace>::SplinesLinearProblemPDSBand(
32 std::size_t const mat_size,
33 std::size_t const kd)
34 : SplinesLinearProblem<ExecSpace>(mat_size)
35 , m_q("q", kd + 1, mat_size)
36{
37 assert(m_q.extent(0) <= mat_size);
38
39 Kokkos::deep_copy(m_q.view_host(), 0.);
40}
41
42template <class ExecSpace>
43SplinesLinearProblemPDSBand<ExecSpace>::~SplinesLinearProblemPDSBand() = default;
44
45template <class ExecSpace>
46double SplinesLinearProblemPDSBand<ExecSpace>::get_element(std::size_t i, std::size_t j) const
47{
48 assert(i < size());
49 assert(j < size());
50
51 // Indices are swapped for an element on subdiagonal
52 if (i > j) {
53 std::swap(i, j);
54 }
55
56 if (j - i < m_q.extent(0)) {
57 return m_q.view_host()(j - i, i);
58 }
59
60 return 0.0;
61}
62
63template <class ExecSpace>
64void SplinesLinearProblemPDSBand<ExecSpace>::set_element(
65 std::size_t i,
66 std::size_t j,
67 double const aij)
68{
69 assert(i < size());
70 assert(j < size());
71
72 // Indices are swapped for an element on subdiagonal
73 if (i > j) {
74 std::swap(i, j);
75 }
76 if (j - i < m_q.extent(0)) {
77 m_q.view_host()(j - i, i) = aij;
78 } else {
79 assert(std::fabs(aij) < 1e-15);
80 }
81}
82
83template <class ExecSpace>
84void SplinesLinearProblemPDSBand<ExecSpace>::setup_solver()
85{
86 int const info = LAPACKE_dpbtrf(
87 LAPACK_ROW_MAJOR,
88 'L',
89 size(),
90 m_q.extent(0) - 1,
91 m_q.view_host().data(),
92 m_q.view_host().stride(
93 0) // m_q.view_host().stride(0) if LAPACK_ROW_MAJOR, m_q.view_host().stride(1) if LAPACK_COL_MAJOR
94 );
95 if (info != 0) {
96 throw std::runtime_error("LAPACKE_dpbtrf failed with error code " + std::to_string(info));
97 }
98
99 // Push on device
100 m_q.modify_host();
101 m_q.sync_device();
102}
103
104template <class ExecSpace>
105void SplinesLinearProblemPDSBand<ExecSpace>::solve(MultiRHS const b, bool const) const
106{
107 assert(b.extent(0) == size());
108
109 auto q_device = m_q.view_device();
110 Kokkos::RangePolicy<ExecSpace> const policy(0, b.extent(1));
111 Kokkos::parallel_for(
112 "pbtrs",
113 policy,
114 KOKKOS_LAMBDA(int const i) {
115 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
116 KokkosBatched::SerialPbtrs<
117 KokkosBatched::Uplo::Lower,
118 KokkosBatched::Algo::Pbtrs::Unblocked>::invoke(q_device, sub_b);
119 });
120}
121
// Explicit instantiations, one per Kokkos execution space whose
// KOKKOS_ENABLE_* macro is defined.
#ifdef KOKKOS_ENABLE_SERIAL
template class SplinesLinearProblemPDSBand<Kokkos::Serial>;
#endif

#ifdef KOKKOS_ENABLE_OPENMP
template class SplinesLinearProblemPDSBand<Kokkos::OpenMP>;
#endif

#ifdef KOKKOS_ENABLE_CUDA
template class SplinesLinearProblemPDSBand<Kokkos::Cuda>;
#endif

#ifdef KOKKOS_ENABLE_HIP
template class SplinesLinearProblemPDSBand<Kokkos::HIP>;
#endif

#ifdef KOKKOS_ENABLE_SYCL
template class SplinesLinearProblemPDSBand<Kokkos::SYCL>;
#endif
137
138} // namespace ddc::detail
The top-level namespace of DDC.