13#include <Kokkos_Core.hpp>
14#include <Kokkos_DualView.hpp>
16#include "splines_linear_problem.hpp"
17#include "splines_linear_problem_dense.hpp"
19namespace ddc::detail {
22
23
24
25
26
27
28
29
30
31
32
33
34
35
// Linear problem class templated on a Kokkos execution space. It derives from
// SplinesLinearProblem<ExecSpace> and re-exports the base's MultiRHS type and
// its `size` member so they can be used unqualified below.
// NOTE(review): the numbers fused to the start of each line jump (37 -> 40
// here), so short lines such as the opening brace and the access specifier
// appear to be missing from this text -- restore them from the upstream file
// before compiling.
36template <
class ExecSpace>
37class SplinesLinearProblem2x2Blocks :
public SplinesLinearProblem<ExecSpace>
40 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
41 using SplinesLinearProblem<ExecSpace>::size;
44
45
46
47
// Storage of the Coo type: one row index, one column index and one value per
// stored entry, each kept in a 1-D LayoutRight view living in
// ExecSpace::memory_space.
// NOTE(review): the line opening the Coo type and the declarations of
// m_nrows / m_ncols (used by the constructors below) are not visible in this
// text -- confirm against the upstream file.
52 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> m_rows_idx;
53 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> m_cols_idx;
54 Kokkos::View<
double*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> m_values;
// Default constructor: zero rows and zero columns; the three views are left
// default-constructed.
56 Coo() : m_nrows(0), m_ncols(0) {}
// Constructor taking the matrix dimensions and the three views. The views are
// taken by value and moved into the members.
// NOTE(review): the initializers consuming nrows_ and ncols_ are not visible
// here; presumably they set m_nrows and m_ncols -- confirm.
58 Coo(std::size_t
const nrows_,
59 std::size_t
const ncols_,
60 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> rows_idx_,
61 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> cols_idx_,
62 Kokkos::View<
double*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> values_)
65 , m_rows_idx(std::move(rows_idx_))
66 , m_cols_idx(std::move(cols_idx_))
67 , m_values(std::move(values_))
// The three views must all have the same length (checked in debug builds only).
69 assert(m_rows_idx.extent(0) == m_cols_idx.extent(0));
70 assert(m_rows_idx.extent(0) == m_values.extent(0));
// Number of stored entries, taken as the length of the values view.
73 KOKKOS_FUNCTION std::size_t nnz()
const
75 return m_values.extent(0);
// Row-count accessor.
// NOTE(review): the body is not visible here; presumably returns m_nrows -- confirm.
78 KOKKOS_FUNCTION std::size_t nrows()
const
// Column-count accessor.
// NOTE(review): the body is not visible here; presumably returns m_ncols -- confirm.
83 KOKKOS_FUNCTION std::size_t ncols()
const
// Three accessors returning, by value, an int view, an int view and a double view.
// NOTE(review): their names and bodies are not visible in this text. The
// sparse-product routine later in this file calls rows_idx(), cols_idx() and
// values() on a Coo, so these are presumably those accessors, in that order --
// confirm against the upstream file.
88 KOKKOS_FUNCTION Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
94 KOKKOS_FUNCTION Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
100 KOKKOS_FUNCTION Kokkos::View<
double*, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
// Top-left block: a SplinesLinearProblem owned through a unique_ptr.
108 std::unique_ptr<SplinesLinearProblem<ExecSpace>> m_top_left_block;
// Dense 2-D host/device storage (Kokkos::DualView, LayoutRight).
// NOTE(review): the declarator line is not visible here; the constructor
// zero-fills m_top_right_block.h_view, so this is presumably the declaration
// of m_top_right_block -- confirm.
109 Kokkos::DualView<
double**, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
// Coo form of the top-right block; assigned in setup_solver().
111 Coo m_top_right_block_coo;
// Dense 2-D host/device storage (Kokkos::DualView, LayoutRight).
// NOTE(review): the declarator line is not visible here; presumably the
// declaration of m_bottom_left_block -- confirm.
112 Kokkos::DualView<
double**, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
// Coo form of the bottom-left block; assigned in setup_solver().
114 Coo m_bottom_left_block_coo;
// Bottom-right block: a SplinesLinearProblem owned through a unique_ptr.
115 std::unique_ptr<SplinesLinearProblem<ExecSpace>> m_bottom_right_block;
119
120
121
122
123
// Constructor.
//   mat_size       -- size forwarded to the SplinesLinearProblem base class.
//   top_left_block -- top-left problem; ownership is moved into this object.
// With nq = m_top_left_block->size(), the initializers below use the extents
// (nq, mat_size - nq) and (mat_size - nq, nq) for the two off-diagonal blocks
// and create a SplinesLinearProblemDense of size mat_size - nq for the
// bottom-right block.
// NOTE(review): the lines opening the m_top_right_block initializer and the
// view labels are not visible in this text -- confirm against upstream.
// NOTE(review): `mat_size - m_top_left_block->size()` is evaluated in the
// initializers, i.e. before the assert below checks size() <= mat_size; with
// std::size_t operands a violation would wrap around before being detected.
124 explicit SplinesLinearProblem2x2Blocks(
125 std::size_t
const mat_size,
126 std::unique_ptr<SplinesLinearProblem<ExecSpace>> top_left_block)
127 : SplinesLinearProblem<ExecSpace>(mat_size)
128 , m_top_left_block(std::move(top_left_block))
131 m_top_left_block->size(),
132 mat_size - m_top_left_block->size())
133 , m_bottom_left_block(
135 mat_size - m_top_left_block->size(),
136 m_top_left_block->size())
137 , m_bottom_right_block(
138 new SplinesLinearProblemDense<ExecSpace>(mat_size - m_top_left_block->size()))
// The top-left block may not be larger than the whole matrix (debug builds only).
140 assert(m_top_left_block->size() <= mat_size);
// Zero-fill the host side of both off-diagonal blocks.
142 Kokkos::deep_copy(m_top_right_block.h_view, 0.);
143 Kokkos::deep_copy(m_bottom_left_block.h_view, 0.);
// Returns the matrix entry at row i, column j by dispatching to the block that
// contains it. Off-diagonal entries are read from the host views.
146 double get_element(std::size_t
const i, std::size_t
const j)
const override
// nq is the size of the top-left block; it marks the row/column where the
// second block row/column starts.
151 std::size_t
const nq = m_top_left_block->size();
// Both indices below nq: entry of the top-left block.
152 if (i < nq && j < nq) {
153 return m_top_left_block->get_element(i, j);
// Both indices at or above nq: entry of the bottom-right block, with shifted indices.
156 if (i >= nq && j >= nq) {
157 return m_bottom_right_block->get_element(i - nq, j - nq);
// Top-right block entry (column shifted by nq).
// NOTE(review): the condition guarding this return is not visible in this
// text; set_element uses `j >= nq` for the same case -- confirm.
161 return m_top_right_block.h_view(i, j - nq);
// Remaining case: bottom-left block entry (row shifted by nq).
164 return m_bottom_left_block.h_view(i - nq, j);
// Writes aij at row i, column j by dispatching to the block that contains the
// entry. Off-diagonal entries are written to the host views only.
167 void set_element(std::size_t
const i, std::size_t
const j,
double const aij)
override
// nq is the size of the top-left block.
172 std::size_t
const nq = m_top_left_block->size();
// Both indices below nq: top-left block.
173 if (i < nq && j < nq) {
174 m_top_left_block->set_element(i, j, aij);
175 }
// Both indices at or above nq: bottom-right block, with shifted indices.
else if (i >= nq && j >= nq) {
176 m_bottom_right_block->set_element(i - nq, j - nq, aij);
177 }
// Only the column is at or above nq: top-right block.
else if (j >= nq) {
178 m_top_right_block.h_view(i, j - nq) = aij;
// Remaining case: bottom-left block.
// NOTE(review): the `else` line introducing this branch is not visible in this text.
180 m_bottom_left_block.h_view(i - nq, j) = aij;
185
186
187
188
189
190
191
192
193
194
195
// Builds a Coo from a dense 2-D view, keeping the entries whose absolute value
// is at least tol (default 1e-14).
// NOTE(review): the line carrying the return type and function name, and the
// parameter name, are not visible in this text; setup_solver() calls
// dense2coo(...) and the body reads `dense_matrix`, so those are presumably
// the names -- confirm.
197 Kokkos::View<
const double**, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
199 double const tol = 1e-14)
// Index and value buffers are allocated with room for every entry of the
// dense matrix (extent(0) * extent(1)); they are shrunk at the end.
201 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> rows_idx(
202 "ddc_splines_coo_rows_idx",
203 dense_matrix.extent(0) * dense_matrix.extent(1));
204 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> cols_idx(
205 "ddc_splines_coo_cols_idx",
206 dense_matrix.extent(0) * dense_matrix.extent(1));
207 Kokkos::View<
double*, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
208 values(
"ddc_splines_coo_values", dense_matrix.extent(0) * dense_matrix.extent(1));
// Rank-0 host/device counter of kept entries: zeroed on the host, then pushed
// to the device.
210 Kokkos::DualView<std::size_t, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
211 n_nonzeros(
"ddc_splines_n_nonzeros");
212 n_nonzeros.h_view() = 0;
213 n_nonzeros.modify_host();
214 n_nonzeros.sync_device();
// The policy range is [0, 1): a single iteration on ExecSpace walks the whole
// matrix row by row, so entries are appended in row-major order and the
// counter is incremented sequentially.
216 Kokkos::parallel_for(
218 Kokkos::RangePolicy(ExecSpace(), 0, 1),
219 KOKKOS_LAMBDA(
const int) {
// NOTE(review): i and j are int while extent() is unsigned, so these are
// signed/unsigned comparisons.
220 for (
int i = 0; i < dense_matrix.extent(0); ++i) {
221 for (
int j = 0; j < dense_matrix.extent(1); ++j) {
222 double const aij = dense_matrix(i, j);
// Keep the entry only when its magnitude reaches the tolerance.
223 if (Kokkos::abs(aij) >= tol) {
224 rows_idx(n_nonzeros.d_view()) = i;
225 cols_idx(n_nonzeros.d_view()) = j;
226 values(n_nonzeros.d_view()) = aij;
227 n_nonzeros.d_view()++;
// Bring the final count back to the host and shrink the three buffers to it.
232 n_nonzeros.modify_device();
233 n_nonzeros.sync_host();
234 Kokkos::resize(rows_idx, n_nonzeros.h_view());
235 Kokkos::resize(cols_idx, n_nonzeros.h_view());
236 Kokkos::resize(values, n_nonzeros.h_view());
// The Coo keeps the dimensions of the dense input.
238 return Coo(dense_matrix.extent(0), dense_matrix.extent(1), rows_idx, cols_idx, values);
// Updates the bottom-right block in place: for every (i, j) it subtracts
// sum over l of bottom_left(i, l) * top_right(l, j), read from the host views.
243 void compute_schur_complement()
// 2-D iteration on the default host execution space; the upper bounds are the
// bottom-right block size in both dimensions.
// NOTE(review): the lower-bound line of the policy is not visible in this text.
245 Kokkos::parallel_for(
246 "compute_schur_complement",
247 Kokkos::MDRangePolicy<Kokkos::DefaultHostExecutionSpace, Kokkos::Rank<2>>(
249 {m_bottom_right_block->size(), m_bottom_right_block->size()}),
// The lambda captures by reference, so it accesses the members of this object directly.
250 [&](
const int i,
const int j) {
// Accumulate the (i, j) entry of the product over the top-left block size.
// NOTE(review): the declaration of `val` is not visible in this text;
// presumably a double initialised to zero -- confirm.
// NOTE(review): l is int while size() is unsigned (signed/unsigned comparison).
252 for (
int l = 0; l < m_top_left_block->size(); ++l) {
253 val += m_bottom_left_block.h_view(i, l) * m_top_right_block.h_view(l, j);
// Store the current entry minus the accumulated product.
// NOTE(review): the receiver of `->set_element` is on a line not visible
// here; presumably m_bottom_right_block -- confirm.
256 ->set_element(i, j, m_bottom_right_block->get_element(i, j) - val);
262
263
264
265
266
267
268
269
270
271
272
273
274
// Prepares both diagonal blocks for solving and converts the off-diagonal
// blocks to Coo form. The statements depend on each other and must stay in
// this order.
275 void setup_solver()
override
// 1. Set up the top-left block.
278 m_top_left_block->setup_solver();
// 2. Push the host-side top-right block to the device, pass its device view to
//    the top-left solve (non-transposed), convert the device view to Coo, then
//    mark the device copy as modified and sync it back to the host.
281 m_top_right_block.modify_host();
282 m_top_right_block.sync_device();
283 m_top_left_block->solve(m_top_right_block.d_view,
false);
284 m_top_right_block_coo = dense2coo(m_top_right_block.d_view);
285 m_top_right_block.modify_device();
286 m_top_right_block.sync_host();
// 3. Push the host-side bottom-left block to the device and convert it to Coo.
289 m_bottom_left_block.modify_host();
290 m_bottom_left_block.sync_device();
291 m_bottom_left_block_coo = dense2coo(m_bottom_left_block.d_view);
// 4. Update the bottom-right block from the host views refreshed above, then
//    set it up.
294 compute_schur_complement();
295 m_bottom_right_block->setup_solver();
299
300
301
302
303
304
305
306
307
308
309
310
// Sparse-times-dense product subtracted from y, for every column j of y:
//   first loop:  y(i, j) -= value * x(k, j), i = stored row, k = stored column;
//   second loop: y(k, j) -= value * x(i, j), the transposed operation.
//   LinOp     -- Coo operand, taken by value.
//   x         -- dense input read by the product.
//   y         -- dense array updated in place.
//   transpose -- defaults to false.
// NOTE(review): the lines selecting between the two loops are not visible in
// this text; presumably `transpose` picks the second one, as the asserts and
// the "_tr" label suggest -- confirm.
311 void spdm_minus1_1(Coo LinOp, MultiRHS
const x, MultiRHS
const y,
bool const transpose =
false)
// Dimension checks (debug builds only): y has as many rows as LinOp has rows
// (columns when transposed), x has as many rows as LinOp has columns (rows
// when transposed), and x and y have the same number of columns.
314 assert((!transpose && LinOp.nrows() == y.extent(0))
315 || (transpose && LinOp.ncols() == y.extent(0)));
316 assert((!transpose && LinOp.ncols() == x.extent(0))
317 || (transpose && LinOp.nrows() == x.extent(0)));
318 assert(x.extent(1) == y.extent(1));
// One parallel iteration per column of y; each iteration walks all stored
// entries in order.
321 Kokkos::parallel_for(
322 "ddc_splines_spdm_minus1_1",
323 Kokkos::RangePolicy(ExecSpace(), 0, y.extent(1)),
324 KOKKOS_LAMBDA(
const int j) {
325 for (
int nz_idx = 0; nz_idx < LinOp.nnz(); ++nz_idx) {
326 const int i = LinOp.rows_idx()(nz_idx);
327 const int k = LinOp.cols_idx()(nz_idx);
328 y(i, j) -= LinOp.values()(nz_idx) * x(k, j);
// Same loop with the roles of the row and column indices swapped.
332 Kokkos::parallel_for(
333 "ddc_splines_spdm_minus1_1_tr",
334 Kokkos::RangePolicy(ExecSpace(), 0, y.extent(1)),
335 KOKKOS_LAMBDA(
const int j) {
336 for (
int nz_idx = 0; nz_idx < LinOp.nnz(); ++nz_idx) {
337 const int i = LinOp.rows_idx()(nz_idx);
338 const int k = LinOp.cols_idx()(nz_idx);
339 y(k, j) -= LinOp.values()(nz_idx) * x(i, j);
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
// Block solve on the right-hand sides b, using the two diagonal-block solvers
// and the Coo off-diagonal blocks prepared by setup_solver().
//   b         -- right-hand sides; its first extent must equal size().
//   transpose -- selects the transposed sequence of steps.
// NOTE(review): the function returns void and works on subviews of b, so b is
// presumably overwritten with the solution -- confirm.
363 void solve(MultiRHS
const b,
bool const transpose)
const override
365 assert(b.extent(0) == size());
// b1: rows [0, nq) of b, nq being the top-left block size.
// NOTE(review): parts of the subview call are not visible in this text.
367 MultiRHS
const b1 = Kokkos::
369 std::pair<std::size_t, std::size_t>(0, m_top_left_block->size()),
// b2: rows [nq, b.extent(0)) of b.
371 MultiRHS
const b2 = Kokkos::
373 std::pair<std::size_t, std::size_t>(m_top_left_block->size(), b.extent(0)),
// Non-transposed sequence: top-left solve on b1, b2 -= bottom_left * b1,
// bottom-right solve on b2, b1 -= top_right * b2.
// NOTE(review): the branch lines selecting between this sequence and the next
// one are not visible in this text.
376 m_top_left_block->solve(b1,
false);
377 spdm_minus1_1(m_bottom_left_block_coo, b1, b2);
378 m_bottom_right_block->solve(b2,
false);
379 spdm_minus1_1(m_top_right_block_coo, b2, b1);
// Transposed sequence: the same four steps in reverse order, each applied in
// its transposed form.
381 spdm_minus1_1(m_top_right_block_coo, b1, b2,
true);
382 m_bottom_right_block->solve(b2,
true);
383 spdm_minus1_1(m_bottom_left_block_coo, b2, b1,
true);
384 m_top_left_block->solve(b1,
true);
// The top-level namespace of DDC.