DDC 0.1.0
Loading...
Searching...
No Matches
splines_linear_problem.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
#include <cassert>
#include <cstddef>
#include <iomanip>
#include <ios>
#include <ostream>

#include <Kokkos_Core.hpp>
13
14namespace ddc::detail {
15
16/**
17 * @brief The parent class for linear problems dedicated to the computation of spline approximations.
18 *
19 * Store a square matrix and provide method to solve a multiple right-hand sides linear problem.
20 * Implementations may have different storage formats, filling methods and multiple right-hand sides linear solvers.
21 */
22template <class ExecSpace>
23class SplinesLinearProblem
24{
25public:
26 /// @brief The type of a Kokkos::View storing multiple right-hand sides.
27 using MultiRHS = Kokkos::View<double**, Kokkos::LayoutRight, ExecSpace>;
28
29private:
30 std::size_t m_size;
31
32protected:
33 explicit SplinesLinearProblem(const std::size_t size) : m_size(size) {}
34
35public:
36 SplinesLinearProblem(SplinesLinearProblem const& x) = delete;
37
38 SplinesLinearProblem(SplinesLinearProblem&& x) = delete;
39
40 /// @brief Destruct
41 virtual ~SplinesLinearProblem() = default;
42
43 SplinesLinearProblem& operator=(SplinesLinearProblem const& x) = delete;
44
45 SplinesLinearProblem& operator=(SplinesLinearProblem&& x) = delete;
46
47 /**
48 * @brief Get an element of the matrix at indexes i, j. It must not be called after `setup_solver`.
49 *
50 * @param i The row index of the desired element.
51 * @param j The column index of the desired element.
52 *
53 * @return The value of the element of the matrix.
54 */
55 virtual double get_element(std::size_t i, std::size_t j) const = 0;
56
57 /**
58 * @brief Set an element of the matrix at indexes i, j. It must not be called after `setup_solver`.
59 *
60 * @param i The row index of the setted element.
61 * @param j The column index of the setted element.
62 * @param aij The value to set in the element of the matrix.
63 */
64 virtual void set_element(std::size_t i, std::size_t j, double aij) = 0;
65
66 /**
67 * @brief Perform a pre-process operation on the solver. Must be called after filling the matrix.
68 */
69 virtual void setup_solver() = 0;
70
71 /**
72 * @brief Solve the multiple right-hand sides linear problem Ax=b or its transposed version A^tx=b inplace.
73 *
74 * @param[in, out] multi_rhs A 2D Kokkos::View storing the multiple right-hand sides of the problem and receiving the corresponding solution.
75 * @param transpose Choose between the direct or transposed version of the linear problem.
76 */
77 virtual void solve(MultiRHS b, bool transpose) const = 0;
78
79 /**
80 * @brief Get the size of the square matrix in one of its dimensions.
81 *
82 * @return The size of the matrix in one of its dimensions.
83 */
84 std::size_t size() const
85 {
86 return m_size;
87 }
88
89 /**
90 * @brief Get the required number of rows of the multi-rhs view passed to solve().
91 *
92 * Implementations may require a number of rows larger than what `size` returns for optimization purposes.
93 *
94 * @return The required number of rows of the multi-rhs view. It is guaranteed to be greater or equal to `size`.
95 */
96 std::size_t required_number_of_rhs_rows() const
97 {
98 std::size_t const nrows = impl_required_number_of_rhs_rows();
99 assert(nrows >= size());
100 return nrows;
101 }
102
103private:
104 virtual std::size_t impl_required_number_of_rhs_rows() const
105 {
106 return m_size;
107 }
108};
109
110/**
111 * @brief Prints the matrix of a SplinesLinearProblem in a std::ostream. It must not be called after `setup_solver`.
112 *
113 * @param[out] os The stream in which the matrix is printed.
114 * @param[in] linear_problem The SplinesLinearProblem of the matrix to print.
115 *
116 * @return The stream in which the matrix is printed.
117**/
118template <class ExecSpace>
119std::ostream& operator<<(std::ostream& os, SplinesLinearProblem<ExecSpace> const& linear_problem)
120{
121 std::size_t const n = linear_problem.size();
122 for (std::size_t i = 0; i < n; ++i) {
123 for (std::size_t j = 0; j < n; ++j) {
124 os << std::fixed << std::setprecision(3) << std::setw(10)
125 << linear_problem.get_element(i, j);
126 }
127 os << "\n";
128 }
129 return os;
130}
131
132} // namespace ddc::detail
The top-level namespace of DDC.