#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>

#include <Kokkos_Core.hpp>
#include <Kokkos_DualView.hpp>

#include "splines_linear_problem.hpp"
#include "splines_linear_problem_dense.hpp"

18namespace ddc::detail {
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35template <
class ExecSpace>
36class SplinesLinearProblem2x2Blocks :
public SplinesLinearProblem<ExecSpace>
39 using typename SplinesLinearProblem<ExecSpace>::MultiRHS;
40 using SplinesLinearProblem<ExecSpace>::size;
43
44
45
46
51 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> m_rows_idx;
52 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> m_cols_idx;
53 Kokkos::View<
double*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> m_values;
55 Coo() : m_nrows(0), m_ncols(0) {}
57 Coo(std::size_t
const nrows_,
58 std::size_t
const ncols_,
59 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> rows_idx_,
60 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> cols_idx_,
61 Kokkos::View<
double*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> values_)
64 , m_rows_idx(std::move(rows_idx_))
65 , m_cols_idx(std::move(cols_idx_))
66 , m_values(std::move(values_))
68 assert(m_rows_idx.extent(0) == m_cols_idx.extent(0));
69 assert(m_rows_idx.extent(0) == m_values.extent(0));
72 KOKKOS_FUNCTION std::size_t nnz()
const
74 return m_values.extent(0);
77 KOKKOS_FUNCTION std::size_t nrows()
const
82 KOKKOS_FUNCTION std::size_t ncols()
const
87 KOKKOS_FUNCTION Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
93 KOKKOS_FUNCTION Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
99 KOKKOS_FUNCTION Kokkos::View<
double*, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
107 std::unique_ptr<SplinesLinearProblem<ExecSpace>> m_top_left_block;
108 Kokkos::DualView<
double**, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
110 Coo m_top_right_block_coo;
111 Kokkos::DualView<
double**, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
113 Coo m_bottom_left_block_coo;
114 std::unique_ptr<SplinesLinearProblem<ExecSpace>> m_bottom_right_block;
118
119
120
121
122
123 explicit SplinesLinearProblem2x2Blocks(
124 std::size_t
const mat_size,
125 std::unique_ptr<SplinesLinearProblem<ExecSpace>> top_left_block)
126 : SplinesLinearProblem<ExecSpace>(mat_size)
127 , m_top_left_block(std::move(top_left_block))
130 m_top_left_block->size(),
131 mat_size - m_top_left_block->size())
132 , m_bottom_left_block(
134 mat_size - m_top_left_block->size(),
135 m_top_left_block->size())
136 , m_bottom_right_block(
137 new SplinesLinearProblemDense<ExecSpace>(mat_size - m_top_left_block->size()))
139 assert(m_top_left_block->size() <= mat_size);
141 Kokkos::deep_copy(m_top_right_block.view_host(), 0.);
142 Kokkos::deep_copy(m_bottom_left_block.view_host(), 0.);
145 double get_element(std::size_t
const i, std::size_t
const j)
const override
150 std::size_t
const nq = m_top_left_block->size();
151 if (i < nq && j < nq) {
152 return m_top_left_block->get_element(i, j);
155 if (i >= nq && j >= nq) {
156 return m_bottom_right_block->get_element(i - nq, j - nq);
160 return m_top_right_block.view_host()(i, j - nq);
163 return m_bottom_left_block.view_host()(i - nq, j);
166 void set_element(std::size_t
const i, std::size_t
const j,
double const aij)
override
171 std::size_t
const nq = m_top_left_block->size();
172 if (i < nq && j < nq) {
173 m_top_left_block->set_element(i, j, aij);
174 }
else if (i >= nq && j >= nq) {
175 m_bottom_right_block->set_element(i - nq, j - nq, aij);
176 }
else if (j >= nq) {
177 m_top_right_block.view_host()(i, j - nq) = aij;
179 m_bottom_left_block.view_host()(i - nq, j) = aij;
184
185
186
187
188
189
190
191
192
193
194
196 Kokkos::View<
double const**, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
198 double const tol = 1e-14)
200 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> rows_idx(
201 "ddc_splines_coo_rows_idx",
202 dense_matrix.extent(0) * dense_matrix.extent(1));
203 Kokkos::View<
int*, Kokkos::LayoutRight,
typename ExecSpace::memory_space> cols_idx(
204 "ddc_splines_coo_cols_idx",
205 dense_matrix.extent(0) * dense_matrix.extent(1));
206 Kokkos::View<
double*, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
207 values(
"ddc_splines_coo_values", dense_matrix.extent(0) * dense_matrix.extent(1));
209 Kokkos::DualView<std::size_t, Kokkos::LayoutRight,
typename ExecSpace::memory_space>
210 n_nonzeros(
"ddc_splines_n_nonzeros");
211 n_nonzeros.view_host()() = 0;
212 n_nonzeros.modify_host();
213 n_nonzeros.sync_device();
215 auto const n_nonzeros_device = n_nonzeros.view_device();
216 Kokkos::parallel_for(
218 Kokkos::RangePolicy(ExecSpace(), 0, 1),
219 KOKKOS_LAMBDA(
int const) {
220 for (
int i = 0; i < dense_matrix.extent(0); ++i) {
221 for (
int j = 0; j < dense_matrix.extent(1); ++j) {
222 double const aij = dense_matrix(i, j);
223 if (Kokkos::abs(aij) >= tol) {
224 rows_idx(n_nonzeros_device()) = i;
225 cols_idx(n_nonzeros_device()) = j;
226 values(n_nonzeros_device()) = aij;
227 n_nonzeros_device()++;
232 n_nonzeros.modify_device();
233 n_nonzeros.sync_host();
234 Kokkos::resize(rows_idx, n_nonzeros.view_host()());
235 Kokkos::resize(cols_idx, n_nonzeros.view_host()());
236 Kokkos::resize(values, n_nonzeros.view_host()());
238 return Coo(dense_matrix.extent(0), dense_matrix.extent(1), rows_idx, cols_idx, values);
243 void compute_schur_complement()
245 auto const bottom_left_block = m_bottom_left_block.view_host();
246 auto const top_right_block = m_top_right_block.view_host();
247 Kokkos::parallel_for(
248 "compute_schur_complement",
249 Kokkos::MDRangePolicy<Kokkos::DefaultHostExecutionSpace, Kokkos::Rank<2>>(
251 {m_bottom_right_block->size(), m_bottom_right_block->size()}),
252 [&](
int const i,
int const j) {
254 for (
int l = 0; l < m_top_left_block->size(); ++l) {
255 val += bottom_left_block(i, l) * top_right_block(l, j);
258 ->set_element(i, j, m_bottom_right_block->get_element(i, j) - val);
264
265
266
267
268
269
270
271
272
273
274
275
276
277 void setup_solver()
override
280 m_top_left_block->setup_solver();
283 m_top_right_block.modify_host();
284 m_top_right_block.sync_device();
285 m_top_left_block->solve(m_top_right_block.view_device(),
false);
286 m_top_right_block_coo = dense2coo(m_top_right_block.view_device());
287 m_top_right_block.modify_device();
288 m_top_right_block.sync_host();
291 m_bottom_left_block.modify_host();
292 m_bottom_left_block.sync_device();
293 m_bottom_left_block_coo = dense2coo(m_bottom_left_block.view_device());
296 compute_schur_complement();
297 m_bottom_right_block->setup_solver();
301
302
303
304
305
306
307
308
309
310
311
312
313 void spdm_minus1_1(Coo LinOp, MultiRHS
const x, MultiRHS
const y,
bool const transpose =
false)
316 assert((!transpose && LinOp.nrows() == y.extent(0))
317 || (transpose && LinOp.ncols() == y.extent(0)));
318 assert((!transpose && LinOp.ncols() == x.extent(0))
319 || (transpose && LinOp.nrows() == x.extent(0)));
320 assert(x.extent(1) == y.extent(1));
323 Kokkos::parallel_for(
324 "ddc_splines_spdm_minus1_1",
325 Kokkos::RangePolicy(ExecSpace(), 0, y.extent(1)),
326 KOKKOS_LAMBDA(
int const j) {
327 for (
int nz_idx = 0; nz_idx < LinOp.nnz(); ++nz_idx) {
328 int const i = LinOp.rows_idx()(nz_idx);
329 int const k = LinOp.cols_idx()(nz_idx);
330 y(i, j) -= LinOp.values()(nz_idx) * x(k, j);
334 Kokkos::parallel_for(
335 "ddc_splines_spdm_minus1_1_tr",
336 Kokkos::RangePolicy(ExecSpace(), 0, y.extent(1)),
337 KOKKOS_LAMBDA(
int const j) {
338 for (
int nz_idx = 0; nz_idx < LinOp.nnz(); ++nz_idx) {
339 int const i = LinOp.rows_idx()(nz_idx);
340 int const k = LinOp.cols_idx()(nz_idx);
341 y(k, j) -= LinOp.values()(nz_idx) * x(i, j);
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365 void solve(MultiRHS
const b,
bool const transpose)
const override
367 assert(b.extent(0) == size());
369 MultiRHS
const b1 = Kokkos::
371 std::pair<std::size_t, std::size_t>(0, m_top_left_block->size()),
373 MultiRHS
const b2 = Kokkos::
375 std::pair<std::size_t, std::size_t>(m_top_left_block->size(), b.extent(0)),
378 m_top_left_block->solve(b1,
false);
379 spdm_minus1_1(m_bottom_left_block_coo, b1, b2);
380 m_bottom_right_block->solve(b2,
false);
381 spdm_minus1_1(m_top_right_block_coo, b2, b1);
383 spdm_minus1_1(m_top_right_block_coo, b1, b2,
true);
384 m_bottom_right_block->solve(b2,
true);
385 spdm_minus1_1(m_bottom_left_block_coo, b2, b1,
true);
386 m_top_left_block->solve(b1,
true);
// The top-level namespace of DDC.