13#include <Kokkos_Core.hpp>
14#include <Kokkos_DualView.hpp>
16#include "splines_linear_problem.hpp"
17#include "splines_linear_problem_dense.hpp"
19namespace ddc::detail {
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36template <
class ExecSpace>
37class SplinesLinearProblem2x2Blocks :
public SplinesLinearProblem<ExecSpace>
40 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
41 using SplinesLinearProblem<ExecSpace>::size;
44
45
46
47
// Entry arrays of this COO (coordinate-format) sparse matrix, all in ExecSpace memory.
// Stored entry n is (m_rows_idx(n), m_cols_idx(n), m_values(n)); the three views are
// expected to have the same length (asserted by the value constructor below).
52 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> m_rows_idx;
53 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> m_cols_idx;
54 Kokkos::View<
double*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> m_values;
// Empty matrix: zero rows and columns; the entry views are left default-constructed.
56 Coo() : m_nrows(0), m_ncols(0) {}
/**
 * @brief Builds a COO matrix from precomputed entry arrays.
 *
 * @param nrows_ Number of rows of the represented matrix.
 * @param ncols_ Number of columns of the represented matrix.
 * @param rows_idx_ Row index of each stored entry (moved into the object).
 * @param cols_idx_ Column index of each stored entry (moved into the object).
 * @param values_ Value of each stored entry (moved into the object).
 *
 * The three arrays must have the same length; this is only checked with assert.
 * NOTE(review): the initializers of m_nrows/m_ncols are on lines not visible here —
 * presumably they take nrows_/ncols_; confirm.
 */
58 Coo(std::size_t
const nrows_,
59 std::size_t
const ncols_,
60 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> rows_idx_,
61 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> cols_idx_,
62 Kokkos::View<
double*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> values_)
65 , m_rows_idx(std::move(rows_idx_))
66 , m_cols_idx(std::move(cols_idx_))
67 , m_values(std::move(values_))
69 assert(m_rows_idx.extent(0) == m_cols_idx.extent(0));
70 assert(m_rows_idx.extent(0) == m_values.extent(0));
// Number of stored entries, i.e. the length of m_values. Callable inside device kernels.
73 KOKKOS_FUNCTION std::size_t nnz()
const
75 return m_values.extent(0);
// Number of rows / columns of the represented matrix (bodies not visible in this view).
78 KOKKOS_FUNCTION std::size_t nrows()
const
83 KOKKOS_FUNCTION std::size_t ncols()
const
// Device-callable read accessors for the row-index, column-index and value arrays.
// Their names are on lines not visible here; other code calls them as rows_idx(),
// cols_idx() and values().
88 KOKKOS_FUNCTION Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
94 KOKKOS_FUNCTION Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
100 KOKKOS_FUNCTION Kokkos::View<
double*, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
// Solver for the top-left block; its size() is the split index nq used throughout the class.
108 std::unique_ptr<SplinesLinearProblem<ExecSpace>> m_top_left_block;
// Dense top-right block (nq rows), with host and device copies. Its name line is not
// visible here; the class refers to it as m_top_right_block. setup_solver() overwrites it
// in place with the result of the top-left solve applied to it.
109 Kokkos::DualView<
double**, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
// Sparse (COO) copy of the solved top-right block, built by setup_solver() and used by solve().
111 Coo m_top_right_block_coo;
// Dense bottom-left block (nq columns), with host and device copies; referred to as
// m_bottom_left_block (name line not visible here).
112 Kokkos::DualView<
double**, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
// Sparse (COO) copy of the bottom-left block, built by setup_solver() and used by solve().
114 Coo m_bottom_left_block_coo;
// Solver for the bottom-right block; the constructor creates a SplinesLinearProblemDense.
// compute_schur_complement() overwrites its entries with the Schur complement.
115 std::unique_ptr<SplinesLinearProblem<ExecSpace>> m_bottom_right_block;
119
120
121
122
123
/**
 * @brief Builds a 2x2 block problem of total size mat_size around a given top-left solver.
 *
 * With nq = top_left_block->size(), the remaining blocks are allocated as follows:
 * - top-right: dense nq x (mat_size - nq) DualView (the opening line of its initializer
 *   is not visible here);
 * - bottom-left: dense (mat_size - nq) x nq DualView;
 * - bottom-right: a SplinesLinearProblemDense of size mat_size - nq.
 * The two off-diagonal blocks are zero-filled on the host.
 *
 * @param mat_size Size of the full square matrix.
 * @param top_left_block Solver for the top-left block; ownership is taken. It is
 *        dereferenced immediately, so it is assumed non-null.
 */
124 explicit SplinesLinearProblem2x2Blocks(
125 std::size_t
const mat_size,
126 std::unique_ptr<SplinesLinearProblem<ExecSpace>> top_left_block)
127 : SplinesLinearProblem<ExecSpace>(mat_size)
128 , m_top_left_block(std::move(top_left_block))
131 m_top_left_block->size(),
132 mat_size - m_top_left_block->size())
133 , m_bottom_left_block(
135 mat_size - m_top_left_block->size(),
136 m_top_left_block->size())
137 , m_bottom_right_block(
138 new SplinesLinearProblemDense<ExecSpace>(mat_size - m_top_left_block->size()))
// NOTE(review): this precondition is checked only after the member initializers above
// have already evaluated mat_size - m_top_left_block->size(). If that difference is
// computed in an unsigned type (mat_size is std::size_t), a violated precondition wraps
// around, and the block allocations are attempted with huge extents before this assert
// is reached. Validating the size before the initializer list would be safer.
140 assert(m_top_left_block->size() <= mat_size);
// The off-diagonal blocks start as all zeros; nonzeros are presumably filled in later
// through set_element(), which writes to these host views.
142 Kokkos::deep_copy(m_top_right_block.view_host(), 0.);
143 Kokkos::deep_copy(m_bottom_left_block.view_host(), 0.);
/**
 * @brief Returns entry (i, j) of the full matrix.
 *
 * With nq = m_top_left_block->size(), the entry is read from the block that contains it:
 * - top-left and bottom-right blocks: through their own get_element();
 * - off-diagonal blocks: directly from their host views.
 * The conditions guarding the two off-diagonal returns are on lines not visible here.
 * By the index shifts, (i, j - nq) addresses the top-right block and (i - nq, j) the
 * bottom-left block.
 *
 * NOTE(review): after setup_solver(), two blocks hold transformed values rather than the
 * originally assembled entries:
 * - the top-right host view holds the top-left solve applied to it;
 * - the bottom-right block holds the Schur complement.
 * This getter returns those transformed values.
 */
146 double get_element(std::size_t
const i, std::size_t
const j)
const override
151 std::size_t
const nq = m_top_left_block->size();
152 if (i < nq && j < nq) {
153 return m_top_left_block->get_element(i, j);
156 if (i >= nq && j >= nq) {
157 return m_bottom_right_block->get_element(i - nq, j - nq);
161 return m_top_right_block.view_host()(i, j - nq);
164 return m_bottom_left_block.view_host()(i - nq, j);
/**
 * @brief Sets entry (i, j) of the full matrix to aij.
 *
 * With nq = m_top_left_block->size(), the entry is routed as follows:
 * - i, j < nq: to the top-left solver;
 * - i, j >= nq: to the bottom-right solver, with indices shifted by nq;
 * - otherwise: written into the host view of the dense top-right or bottom-left block.
 * Those host writes are not flagged as modified here. setup_solver() later calls
 * modify_host() and sync_device() on both off-diagonal blocks before using them.
 */
167 void set_element(std::size_t
const i, std::size_t
const j,
double const aij)
override
172 std::size_t
const nq = m_top_left_block->size();
173 if (i < nq && j < nq) {
174 m_top_left_block->set_element(i, j, aij);
175 }
else if (i >= nq && j >= nq) {
176 m_bottom_right_block->set_element(i - nq, j - nq, aij);
177 }
else if (j >= nq) {
// Top-right block: reaching here implies i < nq, since the cases where i and j are
// both < nq or both >= nq were handled above.
178 m_top_right_block.view_host()(i, j - nq) = aij;
// Remaining case, bottom-left block: i >= nq and j < nq.
180 m_bottom_left_block.view_host()(i - nq, j) = aij;
185
186
187
188
189
190
191
192
193
194
195
/**
 * @brief Converts a dense matrix in ExecSpace memory to COO format, dropping small entries.
 *
 * Only entries with Kokkos::abs(a_ij) >= tol are kept. They are scanned row by row, then
 * column by column, so the resulting arrays are ordered by row, then by column.
 * The output arrays are first allocated at the full size extent(0) * extent(1). They are
 * shrunk with Kokkos::resize once the number of kept entries is known on the host.
 *
 * @param dense_matrix Dense input matrix, read-only, in ExecSpace memory (its declaration
 *        name is on a line not visible here).
 * @param tol Absolute threshold below which entries are discarded (default 1e-14).
 * @return A Coo with the same dimensions as dense_matrix, holding the kept entries.
 */
197 Kokkos::View<
const double**, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
199 double const tol = 1e-14)
201 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> rows_idx(
202 "ddc_splines_coo_rows_idx",
203 dense_matrix.extent(0) * dense_matrix.extent(1));
204 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> cols_idx(
205 "ddc_splines_coo_cols_idx",
206 dense_matrix.extent(0) * dense_matrix.extent(1));
207 Kokkos::View<
double*, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
208 values(
"ddc_splines_coo_values", dense_matrix.extent(0) * dense_matrix.extent(1));
// Counter of kept entries: zeroed on the host, then pushed to the device for the kernel.
210 Kokkos::DualView<std::size_t, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
211 n_nonzeros(
"ddc_splines_n_nonzeros");
212 n_nonzeros.view_host()() = 0;
213 n_nonzeros.modify_host();
214 n_nonzeros.sync_device();
216 auto const n_nonzeros_device = n_nonzeros.view_device();
// Single-iteration range: the whole scan runs sequentially in one ExecSpace thread.
// The running counter is therefore updated without atomics, and the output order is
// deterministic.
217 Kokkos::parallel_for(
219 Kokkos::RangePolicy(ExecSpace(), 0, 1),
220 KOKKOS_LAMBDA(
const int) {
// NOTE(review): int loop indices are compared against unsigned extents, which triggers
// sign-compare warnings. Indices are stored as int in the COO arrays regardless.
221 for (
int i = 0; i < dense_matrix.extent(0); ++i) {
222 for (
int j = 0; j < dense_matrix.extent(1); ++j) {
223 double const aij = dense_matrix(i, j);
224 if (Kokkos::abs(aij) >= tol) {
225 rows_idx(n_nonzeros_device()) = i;
226 cols_idx(n_nonzeros_device()) = j;
227 values(n_nonzeros_device()) = aij;
228 n_nonzeros_device()++;
// Bring the final count back to the host, then shrink the arrays to that length.
// Kokkos::resize preserves existing contents up to the new extent.
233 n_nonzeros.modify_device();
234 n_nonzeros.sync_host();
235 Kokkos::resize(rows_idx, n_nonzeros.view_host()());
236 Kokkos::resize(cols_idx, n_nonzeros.view_host()());
237 Kokkos::resize(values, n_nonzeros.view_host()());
239 return Coo(dense_matrix.extent(0), dense_matrix.extent(1), rows_idx, cols_idx, values);
/**
 * @brief Replaces the bottom-right block D by D - C * X, computed on the host.
 *
 * C is the bottom-left host view and X is the top-right host view. setup_solver() calls
 * this after two steps: it overwrites the top-right block with the top-left solve applied
 * to it, then syncs the result back to the host. Assuming solve() overwrites its argument
 * with A^{-1} times it, the result is the Schur complement D - C A^{-1} B. Here A is the
 * top-left block and B is the originally assembled top-right block.
 *
 * Each (i, j) of the bottom-right block is one iteration of a rank-2 MDRangePolicy on
 * Kokkos::DefaultHostExecutionSpace.
 *
 * NOTE(review): iterations may run concurrently, and each calls get_element/set_element
 * on m_bottom_right_block for its own (i, j). This relies on that block tolerating
 * concurrent access to distinct entries; as built by the constructor, it is a
 * SplinesLinearProblemDense. The accumulator `val` and the range lower bound are declared
 * on lines not visible here (presumably `double val = 0.0;` and `{0, 0}`).
 */
244 void compute_schur_complement()
246 auto const bottom_left_block = m_bottom_left_block.view_host();
247 auto const top_right_block = m_top_right_block.view_host();
248 Kokkos::parallel_for(
249 "compute_schur_complement",
250 Kokkos::MDRangePolicy<Kokkos::DefaultHostExecutionSpace, Kokkos::Rank<2>>(
252 {m_bottom_right_block->size(), m_bottom_right_block->size()}),
253 [&](
const int i,
const int j) {
// (C * X)(i, j). The loop bound m_top_left_block->size() is re-evaluated on every
// iteration and could be hoisted.
255 for (
int l = 0; l < m_top_left_block->size(); ++l) {
256 val += bottom_left_block(i, l) * top_right_block(l, j);
259 ->set_element(i, j, m_bottom_right_block->get_element(i, j) - val);
265
266
267
268
269
270
271
272
273
274
275
276
277
/**
 * @brief Prepares the block solver: sets up the diagonal-block solvers and precomputes
 *        the off-diagonal data used by solve().
 *
 * Order of operations:
 * 1. Set up the top-left solver.
 * 2. Push the host-assembled top-right block to the device and overwrite it in place
 *    with the non-transposed top-left solve. Store a COO copy of the result for solve(),
 *    then sync the result back to the host.
 * 3. Push the host-assembled bottom-left block to the device and store a COO copy of it.
 * 4. Form the Schur complement in the bottom-right block, then set up its solver.
 *    compute_schur_complement() reads the host views of both off-diagonal blocks.
 */
278 void setup_solver()
override
281 m_top_left_block->setup_solver();
// set_element() wrote the top-right entries on the host: flag them and copy them to the device.
284 m_top_right_block.modify_host();
285 m_top_right_block.sync_device();
// In-place solve on the device copy (second argument: transpose = false).
286 m_top_left_block->solve(m_top_right_block.view_device(),
false);
287 m_top_right_block_coo = dense2coo(m_top_right_block.view_device());
// Propagate the solved block back to the host, where compute_schur_complement() reads it.
288 m_top_right_block.modify_device();
289 m_top_right_block.sync_host();
// Same host-to-device transfer for the bottom-left block, which is used unmodified.
292 m_bottom_left_block.modify_host();
293 m_bottom_left_block.sync_device();
294 m_bottom_left_block_coo = dense2coo(m_bottom_left_block.view_device());
// Requires the host views of both off-diagonal blocks to be up to date (done above).
297 compute_schur_complement();
298 m_bottom_right_block->setup_solver();
302
303
304
305
306
307
308
309
310
311
312
313
/**
 * @brief Sparse-times-dense update: y -= LinOp * x, or y -= LinOp^T * x when transpose is true.
 *
 * x and y are multi-right-hand-side views; their second extent is the number of
 * right-hand sides and must match. Each kernel iteration handles one column j and walks
 * over all stored entries of LinOp sequentially. Concurrent iterations therefore write
 * disjoint columns of y, and no atomics are used. LinOp is passed by value, and the
 * KOKKOS_LAMBDA kernels capture it by copy.
 *
 * The branch that selects between the two kernels is on lines not visible here. The
 * second kernel (label "ddc_splines_spdm_minus1_1_tr") implements the transposed update.
 *
 * @param LinOp Sparse operator in COO format.
 * @param x Input right-hand sides. Assumed not to alias y; solve() passes disjoint row
 *        ranges of the same view.
 * @param y Updated in place.
 * @param transpose If true, apply LinOp^T instead of LinOp.
 */
314 void spdm_minus1_1(Coo LinOp, MultiRHS
const x, MultiRHS
const y,
bool const transpose =
false)
// Dimension checks (active only when assertions are enabled). LinOp, or its transpose,
// must map rows of x onto rows of y, and x and y must have the same number of columns.
317 assert((!transpose && LinOp.nrows() == y.extent(0))
318 || (transpose && LinOp.ncols() == y.extent(0)));
319 assert((!transpose && LinOp.ncols() == x.extent(0))
320 || (transpose && LinOp.nrows() == x.extent(0)));
321 assert(x.extent(1) == y.extent(1));
// Non-transposed: y(i, j) -= a_ik * x(k, j) for each stored entry (i, k, a_ik).
324 Kokkos::parallel_for(
325 "ddc_splines_spdm_minus1_1",
326 Kokkos::RangePolicy(ExecSpace(), 0, y.extent(1)),
327 KOKKOS_LAMBDA(
const int j) {
328 for (
int nz_idx = 0; nz_idx < LinOp.nnz(); ++nz_idx) {
329 const int i = LinOp.rows_idx()(nz_idx);
330 const int k = LinOp.cols_idx()(nz_idx);
331 y(i, j) -= LinOp.values()(nz_idx) * x(k, j);
// Transposed: y(k, j) -= a_ik * x(i, j) for each stored entry (i, k, a_ik).
335 Kokkos::parallel_for(
336 "ddc_splines_spdm_minus1_1_tr",
337 Kokkos::RangePolicy(ExecSpace(), 0, y.extent(1)),
338 KOKKOS_LAMBDA(
const int j) {
339 for (
int nz_idx = 0; nz_idx < LinOp.nnz(); ++nz_idx) {
340 const int i = LinOp.rows_idx()(nz_idx);
341 const int k = LinOp.cols_idx()(nz_idx);
342 y(k, j) -= LinOp.values()(nz_idx) * x(i, j);
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366 void solve(MultiRHS
const b,
bool const transpose)
const override
368 assert(b.extent(0) == size());
370 MultiRHS
const b1 = Kokkos::
372 std::pair<std::size_t, std::size_t>(0, m_top_left_block->size()),
374 MultiRHS
const b2 = Kokkos::
376 std::pair<std::size_t, std::size_t>(m_top_left_block->size(), b.extent(0)),
379 m_top_left_block->solve(b1,
false);
380 spdm_minus1_1(m_bottom_left_block_coo, b1, b2);
381 m_bottom_right_block->solve(b2,
false);
382 spdm_minus1_1(m_top_right_block_coo, b2, b1);
384 spdm_minus1_1(m_top_right_block_coo, b1, b2,
true);
385 m_bottom_right_block->solve(b2,
true);
386 spdm_minus1_1(m_bottom_left_block_coo, b2, b1,
true);
387 m_top_left_block->solve(b1,
true);
The top-level namespace of DDC.