DDC 0.1.0
Loading...
Searching...
No Matches
splines_linear_problem_dense.hpp
1// Copyright (C) The DDC development team, see COPYRIGHT.md file
2//
3// SPDX-License-Identifier: MIT
4
5#pragma once
6
7#include <cassert>
8#include <cstddef>
9#include <stdexcept>
10#include <string>
11
12#include <Kokkos_Core.hpp>
13#include <Kokkos_DualView.hpp>
14
15#if __has_include(<mkl_lapacke.h>)
16#include <mkl_lapacke.h>
17#else
18#include <lapacke.h>
19#endif
20
21#include <KokkosBatched_Getrs.hpp>
22#include <KokkosBatched_Util.hpp>
23
24#include "splines_linear_problem.hpp"
25
26namespace ddc::detail {
27
28/**
29 * @brief A dense linear problem dedicated to the computation of a spline approximation.
30 *
31 * The storage format is dense row-major. Lapack is used to perform every matrix and linear solver-related operations.
32 *
33 * @tparam ExecSpace The Kokkos::ExecutionSpace on which operations related to the matrix are supposed to be performed.
34 */
35template <class ExecSpace>
36class SplinesLinearProblemDense : public SplinesLinearProblem<ExecSpace>
37{
38public:
39 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
40 using SplinesLinearProblem<ExecSpace>::size;
41
42protected:
43 Kokkos::DualView<double**, Kokkos::LayoutRight, typename ExecSpace::memory_space> m_a;
44 Kokkos::DualView<int*, typename ExecSpace::memory_space> m_ipiv;
45
46public:
47 /**
48 * @brief SplinesLinearProblemDense constructor.
49 *
50 * @param mat_size The size of one of the dimensions of the square matrix.
51 */
52 explicit SplinesLinearProblemDense(std::size_t const mat_size)
53 : SplinesLinearProblem<ExecSpace>(mat_size)
54 , m_a("a", mat_size, mat_size)
55 , m_ipiv("ipiv", mat_size)
56 {
57 Kokkos::deep_copy(m_a.h_view, 0.);
58 }
59
60 double get_element(std::size_t const i, std::size_t const j) const override
61 {
62 assert(i < size());
63 assert(j < size());
64 return m_a.h_view(i, j);
65 }
66
67 void set_element(std::size_t const i, std::size_t const j, double const aij) override
68 {
69 assert(i < size());
70 assert(j < size());
71 m_a.h_view(i, j) = aij;
72 }
73
74 /**
75 * @brief Perform a pre-process operation on the solver. Must be called after filling the matrix.
76 *
77 * LU-factorize the matrix A and store the pivots using the LAPACK dgetrf() implementation.
78 */
79 void setup_solver() override
80 {
81 int const info = LAPACKE_dgetrf(
82 LAPACK_ROW_MAJOR,
83 size(),
84 size(),
85 m_a.h_view.data(),
86 size(),
87 m_ipiv.h_view.data());
88 if (info != 0) {
89 throw std::runtime_error(
90 "LAPACKE_dgetrf failed with error code " + std::to_string(info));
91 }
92
93 // Convert 1-based index to 0-based index
94 for (int i = 0; i < size(); ++i) {
95 m_ipiv.h_view(i) -= 1;
96 }
97
98 // Push on device
99 m_a.modify_host();
100 m_a.sync_device();
101 m_ipiv.modify_host();
102 m_ipiv.sync_device();
103 }
104
105 /**
106 * @brief Solve the multiple right-hand sides linear problem Ax=b or its transposed version A^tx=b inplace.
107 *
108 * The solver method is gaussian elimination with partial pivoting using the LU-factorized matrix A. The implementation is LAPACK method dgetrs.
109 *
110 * @param[in, out] b A 2D Kokkos::View storing the multiple right-hand sides of the problem and receiving the corresponding solution.
111 * @param transpose Choose between the direct or transposed version of the linear problem.
112 */
113 void solve(MultiRHS const b, bool const transpose) const override
114 {
115 assert(b.extent(0) == size());
116
117 // For order 1 splines, size() can be 0 then we bypass the solver call.
118 if (size() == 0) {
119 return;
120 }
121
122 auto a_device = m_a.d_view;
123 auto ipiv_device = m_ipiv.d_view;
124
125 Kokkos::RangePolicy<ExecSpace> const policy(0, b.extent(1));
126
127 if (transpose) {
128 Kokkos::parallel_for(
129 "gerts",
130 policy,
131 KOKKOS_LAMBDA(const int i) {
132 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
133 KokkosBatched::SerialGetrs<
134 KokkosBatched::Trans::Transpose,
135 KokkosBatched::Algo::Getrs::Unblocked>::
136 invoke(a_device, ipiv_device, sub_b);
137 });
138 } else {
139 Kokkos::parallel_for(
140 "gerts",
141 policy,
142 KOKKOS_LAMBDA(const int i) {
143 auto sub_b = Kokkos::subview(b, Kokkos::ALL, i);
144 KokkosBatched::SerialGetrs<
145 KokkosBatched::Trans::NoTranspose,
146 KokkosBatched::Algo::Getrs::Unblocked>::
147 invoke(a_device, ipiv_device, sub_b);
148 });
149 }
150 }
151};
152
153} // namespace ddc::detail
The top-level namespace of DDC.