DDC 0.4.1
Loading...
Searching...
No Matches
splines_linear_problem_band.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <algorithm>
8#include <cassert>
9#include <cmath>
10#include <cstddef>
11#include <stdexcept>
12#include <string>
13
14#include <Kokkos_Core.hpp>
15#include <Kokkos_DualView.hpp>
16
17#if __has_include(<mkl_lapacke.h>)
18#include <mkl_lapacke.h>
19#else
20#include <lapacke.h>
21#endif
22
23#include <KokkosBatched_Util.hpp>
24
25#include "kokkos-kernels-ext/KokkosBatched_Gbtrs.hpp"
26
27#include "splines_linear_problem.hpp"
28
29namespace ddc::detail {
30
31/**
32 * @brief A band linear problem dedicated to the computation of a spline approximation.
33 *
34 * The storage format is dense row-major. Lapack is used to perform every matrix and linear solver-related operations.
35 *
36 * Given the linear system A*x=b, we assume that A is a square (n by n)
37 * with ku superdiagonals and kl subdiagonals.
38 * All non-zero elements of A are stored in the rectangular matrix q, using
39 * the format required by DGBTRF (LAPACK): diagonals of A are rows of q.
40 * q has 2*kl rows for the subdiagonals, 1 row for the diagonal, and ku rows
41 * for the superdiagonals. (The kl additional rows are needed for pivoting.)
42 *
43 * @tparam ExecSpace The Kokkos::ExecutionSpace on which operations related to the matrix are supposed to be performed.
44 */
45template <class ExecSpace>
46class SplinesLinearProblemBand : public SplinesLinearProblem<ExecSpace>
47{
48public:
49 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
50 using SplinesLinearProblem<ExecSpace>::size;
51
52protected:
53 std::size_t m_kl; // no. of subdiagonals
54 std::size_t m_ku; // no. of superdiagonals
55 Kokkos::DualView<double**, Kokkos::LayoutRight, typename ExecSpace::memory_space>
56 m_q; // band matrix representation
57 Kokkos::DualView<int*, typename ExecSpace::memory_space> m_ipiv; // pivot indices
58
59public:
60 /**
61 * @brief SplinesLinearProblemBand constructor.
62 *
63 * @param mat_size The size of one of the dimensions of the square matrix.
64 * @param kl The number of subdiagonals of the matrix.
65 * @param ku The number of superdiagonals of the matrix.
66 */
67 explicit SplinesLinearProblemBand(
68 std::size_t const mat_size,
69 std::size_t const kl,
70 std::size_t const ku)
71 : SplinesLinearProblem<ExecSpace>(mat_size)
72 , m_kl(kl)
73 , m_ku(ku)
74 /*
75 * The matrix itself stored in band format requires a (kl + ku + 1)*mat_size
76 * allocation, but the LU-factorization requires an additional kl*mat_size block
77 */
78 , m_q("q", 2 * kl + ku + 1, mat_size)
79 , m_ipiv("ipiv", mat_size)
80 {
81 assert(m_kl <= mat_size);
82 assert(m_ku <= mat_size);
83
84 Kokkos::deep_copy(m_q.h_view, 0.);
85 }
86
87private:
88 std::size_t band_storage_row_index(std::size_t const i, std::size_t const j) const
89 {
90 return m_kl + m_ku + i - j;
91 }
92
93public:
94 double get_element(std::size_t const i, std::size_t const j) const override
95 {
96 assert(i < size());
97 assert(j < size());
98 /*
99 * The "row index" of the band format storage identify the (sub/super)-diagonal
100 * while the column index is actually the column index of the matrix. Two layouts
101 * are supported by LAPACKE. The m_kl first lines are irrelevant for the storage of
102 * the matrix itself but required for the storage of its LU factorization.
103 */
104 if (i >= std::
105 max(static_cast<std::ptrdiff_t>(0),
106 static_cast<std::ptrdiff_t>(j) - static_cast<std::ptrdiff_t>(m_ku))
107 && i < std::min(size(), j + m_kl + 1)) {
108 return m_q.h_view(band_storage_row_index(i, j), j);
109 }
110
111 return 0.0;
112 }
113
114 void set_element(std::size_t const i, std::size_t const j, double const aij) override
115 {
116 assert(i < size());
117 assert(j < size());
118 /*
119 * The "row index" of the band format storage identify the (sub/super)-diagonal
120 * while the column index is actually the column index of the matrix. Two layouts
121 * are supported by LAPACKE. The m_kl first lines are irrelevant for the storage of
122 * the matrix itself but required for the storage of its LU factorization.
123 */
124 if (i >= std::
125 max(static_cast<std::ptrdiff_t>(0),
126 static_cast<std::ptrdiff_t>(j) - static_cast<std::ptrdiff_t>(m_ku))
127 && i < std::min(size(), j + m_kl + 1)) {
128 m_q.h_view(band_storage_row_index(i, j), j) = aij;
129 } else {
130 assert(std::fabs(aij) < 1e-20);
131 }
132 }
133
134 /**
135 * @brief Perform a pre-process operation on the solver. Must be called after filling the matrix.
136 *
137 * LU-factorize the matrix A and store the pivots using the LAPACK dgbtrf() implementation.
138 */
139 void setup_solver() override
140 {
141 int const info = LAPACKE_dgbtrf(
142 LAPACK_ROW_MAJOR,
143 size(),
144 size(),
145 m_kl,
146 m_ku,
147 m_q.h_view.data(),
148 m_q.h_view.stride(
149 0), // m_q.h_view.stride(0) if LAPACK_ROW_MAJOR, m_q.h_view.stride(1) if LAPACK_COL_MAJOR
150 m_ipiv.h_view.data());
151 if (info != 0) {
152 throw std::runtime_error(
153 "LAPACKE_dgbtrf failed with error code " + std::to_string(info));
154 }
155
156 // Convert 1-based index to 0-based index
157 for (int i = 0; i < size(); ++i) {
158 m_ipiv.h_view(i) -= 1;
159 }
160
161 // Push on device
162 m_q.modify_host();
163 m_q.sync_device();
164 m_ipiv.modify_host();
165 m_ipiv.sync_device();
166 }
167
168 /**
169 * @brief Solve the multiple right-hand sides linear problem Ax=b or its transposed version A^tx=b inplace.
170 *
171 * The solver method is band gaussian elimination with partial pivoting using the LU-factorized matrix A. The implementation is LAPACK method dgbtrs.
172 *
173 * @param[in, out] b A 2D Kokkos::View storing the multiple right-hand sides of the problem and receiving the corresponding solution.
174 * @param transpose Choose between the direct or transposed version of the linear problem.
175 */
176 void solve(MultiRHS const b, bool const transpose) const override
177 {
178 assert(b.extent(0) == size());
179
180 std::size_t const kl_proxy = m_kl;
181 std::size_t const ku_proxy = m_ku;
182 auto q_device = m_q.d_view;
183 auto ipiv_device = m_ipiv.d_view;
184 Kokkos::RangePolicy<ExecSpace> const policy(0, b.extent(1));
185 if (transpose) {
186 Kokkos::parallel_for(
187 "gbtrs",
188 policy,
189 KOKKOS_LAMBDA(const int i) {
190 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
191 KokkosBatched::SerialGbtrs<
192 KokkosBatched::Trans::Transpose,
193 KokkosBatched::Algo::Level3::Unblocked>::
194 invoke(q_device, sub_b, ipiv_device, kl_proxy, ku_proxy);
195 });
196 } else {
197 Kokkos::parallel_for(
198 "gbtrs",
199 policy,
200 KOKKOS_LAMBDA(const int i) {
201 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
202 KokkosBatched::SerialGbtrs<
203 KokkosBatched::Trans::NoTranspose,
204 KokkosBatched::Algo::Level3::Unblocked>::
205 invoke(q_device, sub_b, ipiv_device, kl_proxy, ku_proxy);
206 });
207 }
208 }
209};
210
211} // namespace ddc::detail
The top-level namespace of DDC.