DDC 0.1.0
Loading...
Searching...
No Matches
splines_linear_problem_band.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <algorithm>
8#include <cassert>
9#include <cmath>
10#include <cstddef>
11#include <stdexcept>
12#include <string>
13
14#include <Kokkos_Core.hpp>
15#include <Kokkos_DualView.hpp>
16
17#if __has_include(<mkl_lapacke.h>)
18#include <mkl_lapacke.h>
19#else
20#include <lapacke.h>
21#endif
22
23#include <KokkosBatched_Gbtrs.hpp>
24#include <KokkosBatched_Util.hpp>
25
26#include "splines_linear_problem.hpp"
27
28namespace ddc::detail {
29
30/**
31 * @brief A band linear problem dedicated to the computation of a spline approximation.
32 *
33 * The storage format is dense row-major. Lapack is used to perform every matrix and linear solver-related operations.
34 *
35 * Given the linear system A*x=b, we assume that A is a square (n by n)
36 * with ku superdiagonals and kl subdiagonals.
37 * All non-zero elements of A are stored in the rectangular matrix q, using
38 * the format required by DGBTRF (LAPACK): diagonals of A are rows of q.
39 * q has 2*kl rows for the subdiagonals, 1 row for the diagonal, and ku rows
40 * for the superdiagonals. (The kl additional rows are needed for pivoting.)
41 *
42 * @tparam ExecSpace The Kokkos::ExecutionSpace on which operations related to the matrix are supposed to be performed.
43 */
44template <class ExecSpace>
45class SplinesLinearProblemBand : public SplinesLinearProblem<ExecSpace>
46{
47public:
48 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
49 using SplinesLinearProblem<ExecSpace>::size;
50
51protected:
52 std::size_t m_kl; // no. of subdiagonals
53 std::size_t m_ku; // no. of superdiagonals
54 Kokkos::DualView<double**, Kokkos::LayoutRight, typename ExecSpace::memory_space>
55 m_q; // band matrix representation
56 Kokkos::DualView<int*, typename ExecSpace::memory_space> m_ipiv; // pivot indices
57
58public:
59 /**
60 * @brief SplinesLinearProblemBand constructor.
61 *
62 * @param mat_size The size of one of the dimensions of the square matrix.
63 * @param kl The number of subdiagonals of the matrix.
64 * @param ku The number of superdiagonals of the matrix.
65 */
66 explicit SplinesLinearProblemBand(
67 std::size_t const mat_size,
68 std::size_t const kl,
69 std::size_t const ku)
70 : SplinesLinearProblem<ExecSpace>(mat_size)
71 , m_kl(kl)
72 , m_ku(ku)
73 /*
74 * The matrix itself stored in band format requires a (kl + ku + 1)*mat_size
75 * allocation, but the LU-factorization requires an additional kl*mat_size block
76 */
77 , m_q("q", 2 * kl + ku + 1, mat_size)
78 , m_ipiv("ipiv", mat_size)
79 {
80 assert(m_kl <= mat_size);
81 assert(m_ku <= mat_size);
82
83 Kokkos::deep_copy(m_q.h_view, 0.);
84 }
85
86private:
87 std::size_t band_storage_row_index(std::size_t const i, std::size_t const j) const
88 {
89 return m_kl + m_ku + i - j;
90 }
91
92public:
93 double get_element(std::size_t const i, std::size_t const j) const override
94 {
95 assert(i < size());
96 assert(j < size());
97 /*
98 * The "row index" of the band format storage identify the (sub/super)-diagonal
99 * while the column index is actually the column index of the matrix. Two layouts
100 * are supported by LAPACKE. The m_kl first lines are irrelevant for the storage of
101 * the matrix itself but required for the storage of its LU factorization.
102 */
103 if (i >= std::
104 max(static_cast<std::ptrdiff_t>(0),
105 static_cast<std::ptrdiff_t>(j) - static_cast<std::ptrdiff_t>(m_ku))
106 && i < std::min(size(), j + m_kl + 1)) {
107 return m_q.h_view(band_storage_row_index(i, j), j);
108 }
109
110 return 0.0;
111 }
112
113 void set_element(std::size_t const i, std::size_t const j, double const aij) override
114 {
115 assert(i < size());
116 assert(j < size());
117 /*
118 * The "row index" of the band format storage identify the (sub/super)-diagonal
119 * while the column index is actually the column index of the matrix. Two layouts
120 * are supported by LAPACKE. The m_kl first lines are irrelevant for the storage of
121 * the matrix itself but required for the storage of its LU factorization.
122 */
123 if (i >= std::
124 max(static_cast<std::ptrdiff_t>(0),
125 static_cast<std::ptrdiff_t>(j) - static_cast<std::ptrdiff_t>(m_ku))
126 && i < std::min(size(), j + m_kl + 1)) {
127 m_q.h_view(band_storage_row_index(i, j), j) = aij;
128 } else {
129 assert(std::fabs(aij) < 1e-20);
130 }
131 }
132
133 /**
134 * @brief Perform a pre-process operation on the solver. Must be called after filling the matrix.
135 *
136 * LU-factorize the matrix A and store the pivots using the LAPACK dgbtrf() implementation.
137 */
138 void setup_solver() override
139 {
140 int const info = LAPACKE_dgbtrf(
141 LAPACK_ROW_MAJOR,
142 size(),
143 size(),
144 m_kl,
145 m_ku,
146 m_q.h_view.data(),
147 m_q.h_view.stride(
148 0), // m_q.h_view.stride(0) if LAPACK_ROW_MAJOR, m_q.h_view.stride(1) if LAPACK_COL_MAJOR
149 m_ipiv.h_view.data());
150 if (info != 0) {
151 throw std::runtime_error(
152 "LAPACKE_dgbtrf failed with error code " + std::to_string(info));
153 }
154
155 // Convert 1-based index to 0-based index
156 for (int i = 0; i < size(); ++i) {
157 m_ipiv.h_view(i) -= 1;
158 }
159
160 // Push on device
161 m_q.modify_host();
162 m_q.sync_device();
163 m_ipiv.modify_host();
164 m_ipiv.sync_device();
165 }
166
167 /**
168 * @brief Solve the multiple right-hand sides linear problem Ax=b or its transposed version A^tx=b inplace.
169 *
170 * The solver method is band gaussian elimination with partial pivoting using the LU-factorized matrix A. The implementation is LAPACK method dgbtrs.
171 *
172 * @param[in, out] b A 2D Kokkos::View storing the multiple right-hand sides of the problem and receiving the corresponding solution.
173 * @param transpose Choose between the direct or transposed version of the linear problem.
174 */
175 void solve(MultiRHS const b, bool const transpose) const override
176 {
177 assert(b.extent(0) == size());
178
179 std::size_t const kl_proxy = m_kl;
180 std::size_t const ku_proxy = m_ku;
181 auto q_device = m_q.d_view;
182 auto ipiv_device = m_ipiv.d_view;
183 Kokkos::RangePolicy<ExecSpace> const policy(0, b.extent(1));
184 if (transpose) {
185 Kokkos::parallel_for(
186 "gbtrs",
187 policy,
188 KOKKOS_LAMBDA(const int i) {
189 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
190 KokkosBatched::SerialGbtrs<
191 KokkosBatched::Trans::Transpose,
192 KokkosBatched::Algo::Gbtrs::Unblocked>::
193 invoke(q_device, sub_b, ipiv_device, kl_proxy, ku_proxy);
194 });
195 } else {
196 Kokkos::parallel_for(
197 "gbtrs",
198 policy,
199 KOKKOS_LAMBDA(const int i) {
200 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
201 KokkosBatched::SerialGbtrs<
202 KokkosBatched::Trans::NoTranspose,
203 KokkosBatched::Algo::Gbtrs::Unblocked>::
204 invoke(q_device, sub_b, ipiv_device, kl_proxy, ku_proxy);
205 });
206 }
207 }
208};
209
210} // namespace ddc::detail
The top-level namespace of DDC.