DDC 0.10.0
Loading...
Searching...
No Matches
splines_linear_problem_band.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <algorithm>
8#include <cassert>
9#if !defined(NDEBUG)
10# include <cmath>
11#endif
12#include <cstddef>
13#include <stdexcept>
14#include <string>
15
16#include <Kokkos_Core.hpp>
17#include <Kokkos_DualView.hpp>
18
19#if __has_include(<mkl_lapacke.h>)
20# include <mkl_lapacke.h>
21#else
22# include <lapacke.h>
23#endif
24
25#include <KokkosBatched_Util.hpp>
26
27#include "kokkos-kernels-ext/KokkosBatched_Gbtrs.hpp"
28
29#include "splines_linear_problem.hpp"
30
31namespace ddc::detail {
32
33/**
34 * @brief A band linear problem dedicated to the computation of a spline approximation.
35 *
36 * The storage format is dense row-major. Lapack is used to perform every matrix and linear solver-related operations.
37 *
38 * Given the linear system A*x=b, we assume that A is a square (n by n)
39 * with ku superdiagonals and kl subdiagonals.
40 * All non-zero elements of A are stored in the rectangular matrix q, using
41 * the format required by DGBTRF (LAPACK): diagonals of A are rows of q.
42 * q has 2*kl rows for the subdiagonals, 1 row for the diagonal, and ku rows
43 * for the superdiagonals. (The kl additional rows are needed for pivoting.)
44 *
45 * @tparam ExecSpace The Kokkos::ExecutionSpace on which operations related to the matrix are supposed to be performed.
46 */
47template <class ExecSpace>
48class SplinesLinearProblemBand : public SplinesLinearProblem<ExecSpace>
49{
50public:
51 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
52 using SplinesLinearProblem<ExecSpace>::size;
53
54protected:
55 std::size_t m_kl; // no. of subdiagonals
56 std::size_t m_ku; // no. of superdiagonals
57 Kokkos::DualView<double**, Kokkos::LayoutRight, typename ExecSpace::memory_space>
58 m_q; // band matrix representation
59 Kokkos::DualView<int*, typename ExecSpace::memory_space> m_ipiv; // pivot indices
60
61public:
62 /**
63 * @brief SplinesLinearProblemBand constructor.
64 *
65 * @param mat_size The size of one of the dimensions of the square matrix.
66 * @param kl The number of subdiagonals of the matrix.
67 * @param ku The number of superdiagonals of the matrix.
68 */
69 explicit SplinesLinearProblemBand(
70 std::size_t const mat_size,
71 std::size_t const kl,
72 std::size_t const ku)
73 : SplinesLinearProblem<ExecSpace>(mat_size)
74 , m_kl(kl)
75 , m_ku(ku)
76 /*
77 * The matrix itself stored in band format requires a (kl + ku + 1)*mat_size
78 * allocation, but the LU-factorization requires an additional kl*mat_size block
79 */
80 , m_q("q", 2 * kl + ku + 1, mat_size)
81 , m_ipiv("ipiv", mat_size)
82 {
83 assert(m_kl <= mat_size);
84 assert(m_ku <= mat_size);
85
86 Kokkos::deep_copy(m_q.view_host(), 0.);
87 }
88
89private:
90 std::size_t band_storage_row_index(std::size_t const i, std::size_t const j) const
91 {
92 return m_kl + m_ku + i - j;
93 }
94
95public:
96 double get_element(std::size_t const i, std::size_t const j) const override
97 {
98 assert(i < size());
99 assert(j < size());
100 /*
101 * The "row index" of the band format storage identify the (sub/super)-diagonal
102 * while the column index is actually the column index of the matrix. Two layouts
103 * are supported by LAPACKE. The m_kl first lines are irrelevant for the storage of
104 * the matrix itself but required for the storage of its LU factorization.
105 */
106 if (i >= std::
107 max(static_cast<std::ptrdiff_t>(0),
108 static_cast<std::ptrdiff_t>(j) - static_cast<std::ptrdiff_t>(m_ku))
109 && i < std::min(size(), j + m_kl + 1)) {
110 return m_q.view_host()(band_storage_row_index(i, j), j);
111 }
112
113 return 0.0;
114 }
115
116 void set_element(std::size_t const i, std::size_t const j, double const aij) override
117 {
118 assert(i < size());
119 assert(j < size());
120 /*
121 * The "row index" of the band format storage identify the (sub/super)-diagonal
122 * while the column index is actually the column index of the matrix. Two layouts
123 * are supported by LAPACKE. The m_kl first lines are irrelevant for the storage of
124 * the matrix itself but required for the storage of its LU factorization.
125 */
126 if (i >= std::
127 max(static_cast<std::ptrdiff_t>(0),
128 static_cast<std::ptrdiff_t>(j) - static_cast<std::ptrdiff_t>(m_ku))
129 && i < std::min(size(), j + m_kl + 1)) {
130 m_q.view_host()(band_storage_row_index(i, j), j) = aij;
131 } else {
132 assert(std::fabs(aij) < 1e-20);
133 }
134 }
135
136 /**
137 * @brief Perform a pre-process operation on the solver. Must be called after filling the matrix.
138 *
139 * LU-factorize the matrix A and store the pivots using the LAPACK dgbtrf() implementation.
140 */
141 void setup_solver() override
142 {
143 int const info = LAPACKE_dgbtrf(
144 LAPACK_ROW_MAJOR,
145 size(),
146 size(),
147 m_kl,
148 m_ku,
149 m_q.view_host().data(),
150 m_q.view_host().stride(
151 0), // m_q.view_host().stride(0) if LAPACK_ROW_MAJOR, m_q.view_host().stride(1) if LAPACK_COL_MAJOR
152 m_ipiv.view_host().data());
153 if (info != 0) {
154 throw std::runtime_error(
155 "LAPACKE_dgbtrf failed with error code " + std::to_string(info));
156 }
157
158 // Convert 1-based index to 0-based index
159 for (std::size_t i = 0; i < size(); ++i) {
160 m_ipiv.view_host()(i) -= 1;
161 }
162
163 // Push on device
164 m_q.modify_host();
165 m_q.sync_device();
166 m_ipiv.modify_host();
167 m_ipiv.sync_device();
168 }
169
170 /**
171 * @brief Solve the multiple right-hand sides linear problem Ax=b or its transposed version A^tx=b inplace.
172 *
173 * The solver method is band gaussian elimination with partial pivoting using the LU-factorized matrix A. The implementation is LAPACK method dgbtrs.
174 *
175 * @param[in, out] b A 2D Kokkos::View storing the multiple right-hand sides of the problem and receiving the corresponding solution.
176 * @param transpose Choose between the direct or transposed version of the linear problem.
177 */
178 void solve(MultiRHS const b, bool const transpose) const override
179 {
180 assert(b.extent(0) == size());
181
182 std::size_t const kl_proxy = m_kl;
183 std::size_t const ku_proxy = m_ku;
184 auto q_device = m_q.view_device();
185 auto ipiv_device = m_ipiv.view_device();
186 Kokkos::RangePolicy<ExecSpace> const policy(0, b.extent(1));
187 if (transpose) {
188 Kokkos::parallel_for(
189 "gbtrs",
190 policy,
191 KOKKOS_LAMBDA(int const i) {
192 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
193 KokkosBatched::SerialGbtrs<
194 KokkosBatched::Trans::Transpose,
195 KokkosBatched::Algo::Level3::Unblocked>::
196 invoke(q_device, sub_b, ipiv_device, kl_proxy, ku_proxy);
197 });
198 } else {
199 Kokkos::parallel_for(
200 "gbtrs",
201 policy,
202 KOKKOS_LAMBDA(int const i) {
203 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
204 KokkosBatched::SerialGbtrs<
205 KokkosBatched::Trans::NoTranspose,
206 KokkosBatched::Algo::Level3::Unblocked>::
207 invoke(q_device, sub_b, ipiv_device, kl_proxy, ku_proxy);
208 });
209 }
210 }
211};
212
213} // namespace ddc::detail
The top-level namespace of DDC.