DDC 0.4.1
Loading...
Searching...
No Matches
splines_linear_problem_pds_tridiag.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <cassert>
8#include <cmath>
9#include <cstddef>
10#include <stdexcept>
11#include <string>
12#include <utility>
13
14#include <Kokkos_Core.hpp>
15#include <Kokkos_DualView.hpp>
16
17#if __has_include(<mkl_lapacke.h>)
18#include <mkl_lapacke.h>
19#else
20#include <lapacke.h>
21#endif
22
23#include <KokkosBatched_Pttrs.hpp>
24#include <KokkosBatched_Util.hpp>
25
26#include "splines_linear_problem.hpp"
27
28namespace ddc::detail {
29
30/**
31 * @brief A positive-definite symmetric tridiagonal linear problem dedicated to the computation of a spline approximation.
32 *
33 * The storage format is dense row-major. Lapack is used to perform every matrix and linear solver-related operations.
34 *
35 * Given the linear system A*x=b, we assume that A is a square (n by n)
36 * with 1 subdiagonal and 1 superdiagonal.
37 * All non-zero elements of A are stored in the rectangular matrix q, using
38 * the format required by DPTTRF (LAPACK): diagonal and superdiagonal of A are rows of q.
39 * q has 1 row for the diagonal and 1 row for the superdiagonal.
40 *
41 * @tparam ExecSpace The Kokkos::ExecutionSpace on which operations related to the matrix are supposed to be performed.
42 */
43template <class ExecSpace>
44class SplinesLinearProblemPDSTridiag : public SplinesLinearProblem<ExecSpace>
45{
46public:
47 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
48 using SplinesLinearProblem<ExecSpace>::size;
49
50protected:
51 Kokkos::DualView<double**, Kokkos::LayoutRight, typename ExecSpace::memory_space>
52 m_q; // pds tridiagonal matrix representation
53
54public:
55 /**
56 * @brief SplinesLinearProblemPDSTridiag constructor.
57 *
58 * @param mat_size The size of one of the dimensions of the square matrix.
59 */
60 explicit SplinesLinearProblemPDSTridiag(std::size_t const mat_size)
61 : SplinesLinearProblem<ExecSpace>(mat_size)
62 , m_q("q", 2, mat_size)
63 {
64 Kokkos::deep_copy(m_q.h_view, 0.);
65 }
66
67 double get_element(std::size_t i, std::size_t j) const override
68 {
69 assert(i < size());
70 assert(j < size());
71
72 // Indices are swapped for an element on subdiagonal
73 if (i > j) {
74 std::swap(i, j);
75 }
76
77 if (j - i < 2) {
78 return m_q.h_view(j - i, i);
79 }
80
81 return 0.0;
82 }
83
84 void set_element(std::size_t i, std::size_t j, double const aij) override
85 {
86 assert(i < size());
87 assert(j < size());
88
89 // Indices are swapped for an element on subdiagonal
90 if (i > j) {
91 std::swap(i, j);
92 }
93 if (j - i < 2) {
94 m_q.h_view(j - i, i) = aij;
95 } else {
96 assert(std::fabs(aij) < 1e-20);
97 }
98 }
99
100 /**
101 * @brief Perform a pre-process operation on the solver. Must be called after filling the matrix.
102 *
103 * LU-factorize the matrix A and store the pivots using the LAPACK dpttrf() implementation.
104 */
105 void setup_solver() override
106 {
107 int const info = LAPACKE_dpttrf(
108 size(),
109 m_q.h_view.data(),
110 m_q.h_view.data() + m_q.h_view.stride(0));
111 if (info != 0) {
112 throw std::runtime_error(
113 "LAPACKE_dpttrf failed with error code " + std::to_string(info));
114 }
115
116 // Push on device
117 m_q.modify_host();
118 m_q.sync_device();
119 }
120
121 /**
122 * @brief Solve the multiple right-hand sides linear problem Ax=b or its transposed version A^tx=b inplace.
123 *
124 * The solver method is band gaussian elimination with partial pivoting using the LU-factorized matrix A. The implementation is LAPACK method dpttrs.
125 *
126 * @param[in, out] b A 2D Kokkos::View storing the multiple right-hand sides of the problem and receiving the corresponding solution.
127 * @param transpose Choose between the direct or transposed version of the linear problem (unused for a symmetric problem).
128 */
129 void solve(MultiRHS const b, bool const) const override
130 {
131 assert(b.extent(0) == size());
132 auto q_device = m_q.d_view;
133 auto d = Kokkos::subview(q_device, 0, Kokkos::ALL);
134 auto e = Kokkos::
135 subview(q_device, 1, Kokkos::pair<int, int>(0, q_device.extent_int(1) - 1));
136 Kokkos::RangePolicy<ExecSpace> const policy(0, b.extent(1));
137 Kokkos::parallel_for(
138 "pttrs",
139 policy,
140 KOKKOS_LAMBDA(const int i) {
141 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
142 KokkosBatched::SerialPttrs<
143 KokkosBatched::Uplo::Lower,
144 KokkosBatched::Algo::Pttrs::Unblocked>::invoke(d, e, sub_b);
145 });
146 }
147};
148
149} // namespace ddc::detail
The top-level namespace of DDC.