DDC 0.10.0
Loading...
Searching...
No Matches
splines_linear_problem_pds_tridiag.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <cassert>
8#if !defined(NDEBUG)
9# include <cmath>
10#endif
11#include <cstddef>
12#include <stdexcept>
13#include <string>
14#include <utility>
15
16#include <Kokkos_Core.hpp>
17#include <Kokkos_DualView.hpp>
18
19#if __has_include(<mkl_lapacke.h>)
20# include <mkl_lapacke.h>
21#else
22# include <lapacke.h>
23#endif
24
25#include <KokkosBatched_Pttrs.hpp>
26#include <KokkosBatched_Util.hpp>
27
28#include "splines_linear_problem.hpp"
29
30namespace ddc::detail {
31
32/**
33 * @brief A positive-definite symmetric tridiagonal linear problem dedicated to the computation of a spline approximation.
34 *
35 * The storage format is dense row-major. Lapack is used to perform every matrix and linear solver-related operations.
36 *
37 * Given the linear system A*x=b, we assume that A is a square (n by n)
38 * with 1 subdiagonal and 1 superdiagonal.
39 * All non-zero elements of A are stored in the rectangular matrix q, using
40 * the format required by DPTTRF (LAPACK): diagonal and superdiagonal of A are rows of q.
41 * q has 1 row for the diagonal and 1 row for the superdiagonal.
42 *
43 * @tparam ExecSpace The Kokkos::ExecutionSpace on which operations related to the matrix are supposed to be performed.
44 */
45template <class ExecSpace>
46class SplinesLinearProblemPDSTridiag : public SplinesLinearProblem<ExecSpace>
47{
48public:
49 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
50 using SplinesLinearProblem<ExecSpace>::size;
51
52protected:
53 Kokkos::DualView<double**, Kokkos::LayoutRight, typename ExecSpace::memory_space>
54 m_q; // pds tridiagonal matrix representation
55
56public:
57 /**
58 * @brief SplinesLinearProblemPDSTridiag constructor.
59 *
60 * @param mat_size The size of one of the dimensions of the square matrix.
61 */
62 explicit SplinesLinearProblemPDSTridiag(std::size_t const mat_size)
63 : SplinesLinearProblem<ExecSpace>(mat_size)
64 , m_q("q", 2, mat_size)
65 {
66 Kokkos::deep_copy(m_q.view_host(), 0.);
67 }
68
69 double get_element(std::size_t i, std::size_t j) const override
70 {
71 assert(i < size());
72 assert(j < size());
73
74 // Indices are swapped for an element on subdiagonal
75 if (i > j) {
76 std::swap(i, j);
77 }
78
79 if (j - i < 2) {
80 return m_q.view_host()(j - i, i);
81 }
82
83 return 0.0;
84 }
85
86 void set_element(std::size_t i, std::size_t j, double const aij) override
87 {
88 assert(i < size());
89 assert(j < size());
90
91 // Indices are swapped for an element on subdiagonal
92 if (i > j) {
93 std::swap(i, j);
94 }
95 if (j - i < 2) {
96 m_q.view_host()(j - i, i) = aij;
97 } else {
98 assert(std::fabs(aij) < 1e-20);
99 }
100 }
101
102 /**
103 * @brief Perform a pre-process operation on the solver. Must be called after filling the matrix.
104 *
105 * LU-factorize the matrix A and store the pivots using the LAPACK dpttrf() implementation.
106 */
107 void setup_solver() override
108 {
109 int const info = LAPACKE_dpttrf(
110 size(),
111 m_q.view_host().data(),
112 m_q.view_host().data() + m_q.view_host().stride(0));
113 if (info != 0) {
114 throw std::runtime_error(
115 "LAPACKE_dpttrf failed with error code " + std::to_string(info));
116 }
117
118 // Push on device
119 m_q.modify_host();
120 m_q.sync_device();
121 }
122
123 /**
124 * @brief Solve the multiple right-hand sides linear problem Ax=b or its transposed version A^tx=b inplace.
125 *
126 * The solver method is band gaussian elimination with partial pivoting using the LU-factorized matrix A. The implementation is LAPACK method dpttrs.
127 *
128 * @param[in, out] b A 2D Kokkos::View storing the multiple right-hand sides of the problem and receiving the corresponding solution.
129 * @param transpose Choose between the direct or transposed version of the linear problem (unused for a symmetric problem).
130 */
131 void solve(MultiRHS const b, bool const) const override
132 {
133 assert(b.extent(0) == size());
134 auto q_device = m_q.view_device();
135 auto d = Kokkos::subview(q_device, 0, Kokkos::ALL);
136 auto e = Kokkos::
137 subview(q_device, 1, Kokkos::pair<int, int>(0, q_device.extent_int(1) - 1));
138 Kokkos::RangePolicy<ExecSpace> const policy(0, b.extent(1));
139 Kokkos::parallel_for(
140 "pttrs",
141 policy,
142 KOKKOS_LAMBDA(int const i) {
143 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
144 KokkosBatched::SerialPttrs<
145 KokkosBatched::Uplo::Lower,
146 KokkosBatched::Algo::Pttrs::Unblocked>::invoke(d, e, sub_b);
147 });
148 }
149};
150
151} // namespace ddc::detail
The top-level namespace of DDC.