DDC 0.4.1
Loading...
Searching...
No Matches
splines_linear_problem_dense.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <cassert>
8#include <cstddef>
9#include <stdexcept>
10#include <string>
11
12#include <Kokkos_Core.hpp>
13#include <Kokkos_DualView.hpp>
14
15#if __has_include(<mkl_lapacke.h>)
16#include <mkl_lapacke.h>
17#else
18#include <lapacke.h>
19#endif
20
21#include <KokkosBatched_Util.hpp>
22
23#include "kokkos-kernels-ext/KokkosBatched_Getrs.hpp"
24
25#include "splines_linear_problem.hpp"
26
27namespace ddc::detail {
28
29/**
30 * @brief A dense linear problem dedicated to the computation of a spline approximation.
31 *
32 * The storage format is dense row-major. Lapack is used to perform every matrix and linear solver-related operations.
33 *
34 * @tparam ExecSpace The Kokkos::ExecutionSpace on which operations related to the matrix are supposed to be performed.
35 */
36template <class ExecSpace>
37class SplinesLinearProblemDense : public SplinesLinearProblem<ExecSpace>
38{
39public:
40 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
41 using SplinesLinearProblem<ExecSpace>::size;
42
43protected:
44 Kokkos::DualView<double**, Kokkos::LayoutRight, typename ExecSpace::memory_space> m_a;
45 Kokkos::DualView<int*, typename ExecSpace::memory_space> m_ipiv;
46
47public:
48 /**
49 * @brief SplinesLinearProblemDense constructor.
50 *
51 * @param mat_size The size of one of the dimensions of the square matrix.
52 */
53 explicit SplinesLinearProblemDense(std::size_t const mat_size)
54 : SplinesLinearProblem<ExecSpace>(mat_size)
55 , m_a("a", mat_size, mat_size)
56 , m_ipiv("ipiv", mat_size)
57 {
58 Kokkos::deep_copy(m_a.h_view, 0.);
59 }
60
61 double get_element(std::size_t const i, std::size_t const j) const override
62 {
63 assert(i < size());
64 assert(j < size());
65 return m_a.h_view(i, j);
66 }
67
68 void set_element(std::size_t const i, std::size_t const j, double const aij) override
69 {
70 assert(i < size());
71 assert(j < size());
72 m_a.h_view(i, j) = aij;
73 }
74
75 /**
76 * @brief Perform a pre-process operation on the solver. Must be called after filling the matrix.
77 *
78 * LU-factorize the matrix A and store the pivots using the LAPACK dgetrf() implementation.
79 */
80 void setup_solver() override
81 {
82 int const info = LAPACKE_dgetrf(
83 LAPACK_ROW_MAJOR,
84 size(),
85 size(),
86 m_a.h_view.data(),
87 size(),
88 m_ipiv.h_view.data());
89 if (info != 0) {
90 throw std::runtime_error(
91 "LAPACKE_dgetrf failed with error code " + std::to_string(info));
92 }
93
94 // Convert 1-based index to 0-based index
95 for (int i = 0; i < size(); ++i) {
96 m_ipiv.h_view(i) -= 1;
97 }
98
99 // Push on device
100 m_a.modify_host();
101 m_a.sync_device();
102 m_ipiv.modify_host();
103 m_ipiv.sync_device();
104 }
105
106 /**
107 * @brief Solve the multiple right-hand sides linear problem Ax=b or its transposed version A^tx=b inplace.
108 *
109 * The solver method is gaussian elimination with partial pivoting using the LU-factorized matrix A. The implementation is LAPACK method dgetrs.
110 *
111 * @param[in, out] b A 2D Kokkos::View storing the multiple right-hand sides of the problem and receiving the corresponding solution.
112 * @param transpose Choose between the direct or transposed version of the linear problem.
113 */
114 void solve(MultiRHS const b, bool const transpose) const override
115 {
116 assert(b.extent(0) == size());
117
118 // For order 1 splines, size() can be 0 then we bypass the solver call.
119 if (size() == 0) {
120 return;
121 }
122
123 auto a_device = m_a.d_view;
124 auto ipiv_device = m_ipiv.d_view;
125
126 Kokkos::RangePolicy<ExecSpace> const policy(0, b.extent(1));
127
128 if (transpose) {
129 Kokkos::parallel_for(
130 "gerts",
131 policy,
132 KOKKOS_LAMBDA(const int i) {
133 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
134 KokkosBatched::SerialGetrs<
135 KokkosBatched::Trans::Transpose,
136 KokkosBatched::Algo::Level3::Unblocked>::
137 invoke(a_device, ipiv_device, sub_b);
138 });
139 } else {
140 Kokkos::parallel_for(
141 "gerts",
142 policy,
143 KOKKOS_LAMBDA(const int i) {
144 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
145 KokkosBatched::SerialGetrs<
146 KokkosBatched::Trans::NoTranspose,
147 KokkosBatched::Algo::Level3::Unblocked>::
148 invoke(a_device, ipiv_device, sub_b);
149 });
150 }
151 }
152};
153
154} // namespace ddc::detail
The top-level namespace of DDC.