DDC 0.4.1
Loading...
Searching...
No Matches
splines_linear_problem_pds_band.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>

#include <Kokkos_Core.hpp>
#include <Kokkos_DualView.hpp>

#if __has_include(<mkl_lapacke.h>)
#include <mkl_lapacke.h>
#else
#include <lapacke.h>
#endif

#include <KokkosBatched_Pbtrs.hpp>
#include <KokkosBatched_Util.hpp>

#include "splines_linear_problem.hpp"
26
27namespace ddc::detail {
28
29/**
30 * @brief A positive-definite symmetric band linear problem dedicated to the computation of a spline approximation.
31 *
32 * The storage format is dense row-major. Lapack is used to perform every matrix and linear solver-related operations.
33 *
34 * Given the linear system A*x=b, we assume that A is a square (n by n)
35 * with kd sub and superdiagonals.
36 * All non-zero elements of A are stored in the rectangular matrix q, using
37 * the format required by DPBTRF (LAPACK): (super-)diagonals of A are rows of q.
38 * q has 1 row for the diagonal and kd rows for the superdiagonals.
39 *
40 * @tparam ExecSpace The Kokkos::ExecutionSpace on which operations related to the matrix are supposed to be performed.
41 */
42template <class ExecSpace>
43class SplinesLinearProblemPDSBand : public SplinesLinearProblem<ExecSpace>
44{
45public:
46 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
47 using SplinesLinearProblem<ExecSpace>::size;
48
49protected:
50 Kokkos::DualView<double**, Kokkos::LayoutRight, typename ExecSpace::memory_space>
51 m_q; // pds band matrix representation
52
53public:
54 /**
55 * @brief SplinesLinearProblemPDSBand constructor.
56 *
57 * @param mat_size The size of one of the dimensions of the square matrix.
58 * @param kd The number of sub/superdiagonals of the matrix.
59 */
60 explicit SplinesLinearProblemPDSBand(std::size_t const mat_size, std::size_t const kd)
61 : SplinesLinearProblem<ExecSpace>(mat_size)
62 , m_q("q", kd + 1, mat_size)
63 {
64 assert(m_q.extent(0) <= mat_size);
65
66 Kokkos::deep_copy(m_q.h_view, 0.);
67 }
68
69 double get_element(std::size_t i, std::size_t j) const override
70 {
71 assert(i < size());
72 assert(j < size());
73
74 // Indices are swapped for an element on subdiagonal
75 if (i > j) {
76 std::swap(i, j);
77 }
78
79 if (j - i < m_q.extent(0)) {
80 return m_q.h_view(j - i, i);
81 }
82
83 return 0.0;
84 }
85
86 void set_element(std::size_t i, std::size_t j, double const aij) override
87 {
88 assert(i < size());
89 assert(j < size());
90
91 // Indices are swapped for an element on subdiagonal
92 if (i > j) {
93 std::swap(i, j);
94 }
95 if (j - i < m_q.extent(0)) {
96 m_q.h_view(j - i, i) = aij;
97 } else {
98 assert(std::fabs(aij) < 1e-20);
99 }
100 }
101
102 /**
103 * @brief Perform a pre-process operation on the solver. Must be called after filling the matrix.
104 *
105 * LU-factorize the matrix A and store the pivots using the LAPACK dpbtrf() implementation.
106 */
107 void setup_solver() override
108 {
109 int const info = LAPACKE_dpbtrf(
110 LAPACK_ROW_MAJOR,
111 'L',
112 size(),
113 m_q.extent(0) - 1,
114 m_q.h_view.data(),
115 m_q.h_view.stride(
116 0) // m_q.h_view.stride(0) if LAPACK_ROW_MAJOR, m_q.h_view.stride(1) if LAPACK_COL_MAJOR
117 );
118 if (info != 0) {
119 throw std::runtime_error(
120 "LAPACKE_dpbtrf failed with error code " + std::to_string(info));
121 }
122
123 // Push on device
124 m_q.modify_host();
125 m_q.sync_device();
126 }
127
128 /**
129 * @brief Solve the multiple right-hand sides linear problem Ax=b or its transposed version A^tx=b inplace.
130 *
131 * The solver method is band gaussian elimination with partial pivoting using the LU-factorized matrix A. The implementation is LAPACK method dpbtrs.
132 *
133 * @param[in, out] b A 2D Kokkos::View storing the multiple right-hand sides of the problem and receiving the corresponding solution.
134 * @param transpose Choose between the direct or transposed version of the linear problem (unused for a symmetric problem).
135 */
136 void solve(MultiRHS const b, bool const) const override
137 {
138 assert(b.extent(0) == size());
139
140 auto q_device = m_q.d_view;
141 Kokkos::RangePolicy<ExecSpace> const policy(0, b.extent(1));
142 Kokkos::parallel_for(
143 "pbtrs",
144 policy,
145 KOKKOS_LAMBDA(const int i) {
146 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
147 KokkosBatched::SerialPbtrs<
148 KokkosBatched::Uplo::Lower,
149 KokkosBatched::Algo::Pbtrs::Unblocked>::invoke(q_device, sub_b);
150 });
151 }
152};
153
154} // namespace ddc::detail
The top-level namespace of DDC.