#ifndef TATAMI_CHUNKED_CUSTOM_CHUNK_COORDINATOR_HPP
#define TATAMI_CHUNKED_CUSTOM_CHUNK_COORDINATOR_HPP

#include <vector>
#include <utility>
#include <tuple>
#include <type_traits>
#include <algorithm>
#include <stdexcept>
#include <cstddef>

#include "ChunkDimensionStats.hpp"
15namespace CustomChunkedMatrix_internal {
// Scratch buffer for extracting a single element of the primary dimension
// from a dense chunk: just a flat buffer of cached values, so a plain
// std::vector suffices.
template<typename CachedValue_>
using DenseSingleWorkspace = std::vector<CachedValue_>;
// Scratch space for extracting a single element of the primary dimension from
// a sparse chunk. Pre-allocates one row of values/indices per primary element
// of the chunk, exposed as per-row pointers into contiguous pools so that a
// chunk's extract() method can write each primary element's results separately.
template<typename CachedValue_, typename Index_>
class SparseSingleWorkspace {
public:
    // 'needs_value'/'needs_index' skip allocation of the corresponding pool
    // when the caller doesn't want values or indices, respectively.
    SparseSingleWorkspace(std::size_t primary_chunkdim, std::size_t secondary_chunkdim, bool needs_value, bool needs_index) : my_number(primary_chunkdim) {
        std::size_t total_size = primary_chunkdim * secondary_chunkdim;

        if (needs_value) {
            my_value_pool.resize(total_size);
            my_values.reserve(primary_chunkdim);
            auto vptr = my_value_pool.data();
            for (std::size_t p = 0; p < primary_chunkdim; ++p, vptr += secondary_chunkdim) {
                my_values.push_back(vptr);
            }
        }

        if (needs_index) {
            my_index_pool.resize(total_size);
            my_indices.reserve(primary_chunkdim);
            auto iptr = my_index_pool.data();
            for (std::size_t p = 0; p < primary_chunkdim; ++p, iptr += secondary_chunkdim) {
                my_indices.push_back(iptr);
            }
        }
    }

    // Copying is forbidden, as 'my_values'/'my_indices' hold raw pointers into
    // this object's pools; a copy would alias the original's buffers.
    SparseSingleWorkspace(const SparseSingleWorkspace&) = delete;
    SparseSingleWorkspace& operator=(const SparseSingleWorkspace&) = delete;

    // Moves are fine: moving a std::vector transfers its heap buffer, so the
    // cached pointers into the pools remain valid.
    SparseSingleWorkspace(SparseSingleWorkspace&&) = default;
    SparseSingleWorkspace& operator=(SparseSingleWorkspace&&) = default;

private:
    std::vector<CachedValue_> my_value_pool;
    std::vector<Index_> my_index_pool;
    std::vector<CachedValue_*> my_values;
    std::vector<Index_*> my_indices;
    std::vector<Index_> my_number;

public:
    std::vector<CachedValue_*>& get_values() {
        return my_values;
    }

    std::vector<Index_*>& get_indices() {
        return my_indices;
    }

    std::vector<Index_>& get_number() {
        return my_number;
    }
};
80template<
typename Index_,
bool sparse_,
class Chunk_>
81class ChunkCoordinator {
83 ChunkCoordinator(ChunkDimensionStats<Index_> row_stats, ChunkDimensionStats<Index_> col_stats, std::vector<Chunk_> chunk_array,
bool row_major) :
84 my_row_stats(std::move(row_stats)), my_col_stats(std::move(col_stats)), my_chunk_array(std::move(chunk_array)), my_row_major(row_major)
86 if (
static_cast<size_t>(my_row_stats.num_chunks) *
static_cast<size_t>(my_col_stats.num_chunks) != my_chunk_array.size()) {
87 throw std::runtime_error(
"length of 'chunks' should be equal to the product of the number of chunks along each row and column");
92 ChunkDimensionStats<Index_> my_row_stats;
93 ChunkDimensionStats<Index_> my_col_stats;
94 std::vector<Chunk_> my_chunk_array;
100 Index_ get_num_chunks_per_row()
const {
101 return my_col_stats.num_chunks;
104 Index_ get_num_chunks_per_column()
const {
105 return my_row_stats.num_chunks;
108 Index_ get_nrow()
const {
109 return my_row_stats.dimension_extent;
112 Index_ get_ncol()
const {
113 return my_col_stats.dimension_extent;
116 bool prefer_rows_internal()
const {
118 return get_num_chunks_per_column() > get_num_chunks_per_row();
121 Index_ get_chunk_nrow()
const {
122 return my_row_stats.chunk_length;
125 Index_ get_chunk_ncol()
const {
126 return my_col_stats.chunk_length;
130 Index_ get_secondary_dim(
bool row)
const {
132 return my_col_stats.dimension_extent;
134 return my_row_stats.dimension_extent;
138 Index_ get_primary_chunkdim(
bool row)
const {
140 return my_row_stats.chunk_length;
142 return my_col_stats.chunk_length;
146 Index_ get_secondary_chunkdim(
bool row)
const {
148 return my_col_stats.chunk_length;
150 return my_row_stats.chunk_length;
155 Index_ get_primary_chunkdim(
bool row, Index_ chunk_id)
const {
160 std::pair<size_t, size_t> offset_and_increment(
bool row, Index_ chunk_id)
const {
161 size_t num_chunks = (my_row_major ? get_num_chunks_per_row() : get_num_chunks_per_column());
162 if (row == my_row_major) {
163 return std::pair<size_t, size_t>(
static_cast<size_t>(chunk_id) * num_chunks, 1);
165 return std::pair<size_t, size_t>(chunk_id, num_chunks);
169 template<
class ExtractFunction_>
170 void extract_secondary_block(
173 Index_ secondary_block_start,
174 Index_ secondary_block_length,
175 ExtractFunction_ extract)
177 auto secondary_chunkdim = get_secondary_chunkdim(row);
178 Index_ start_chunk_index = secondary_block_start / secondary_chunkdim;
179 Index_ secondary_start_pos = start_chunk_index * secondary_chunkdim;
180 Index_ secondary_block_end = secondary_block_start + secondary_block_length;
181 Index_ end_chunk_index = secondary_block_end / secondary_chunkdim + (secondary_block_end % secondary_chunkdim > 0);
183 auto oi = offset_and_increment(row, chunk_id);
184 auto offset = std::get<0>(oi);
185 auto increment = std::get<1>(oi);
186 offset += increment *
static_cast<size_t>(start_chunk_index);
188 for (Index_ c = start_chunk_index; c < end_chunk_index; ++c) {
189 const auto& chunk = my_chunk_array[offset];
190 Index_ from = (c == start_chunk_index ? secondary_block_start - secondary_start_pos : 0);
191 Index_ to = (c + 1 == end_chunk_index ? secondary_block_end - secondary_start_pos : secondary_chunkdim);
192 Index_ len = to - from;
196 if constexpr(sparse_) {
197 extract(chunk, from, len, secondary_start_pos);
199 extract(chunk, from, len);
202 secondary_start_pos += to;
207 template<
class ExtractFunction_>
208 void extract_secondary_index(
211 const std::vector<Index_>& secondary_indices,
212 std::vector<Index_>& chunk_indices_buffer,
213 ExtractFunction_ extract)
215 if (secondary_indices.empty()) {
219 auto secondary_chunkdim = get_secondary_chunkdim(row);
220 Index_ start_chunk_index = secondary_indices.front() / secondary_chunkdim;
221 Index_ secondary_start_pos = start_chunk_index * secondary_chunkdim;
223 auto oi = offset_and_increment(row, chunk_id);
224 auto offset = std::get<0>(oi);
225 auto increment = std::get<1>(oi);
226 offset += increment *
static_cast<size_t>(start_chunk_index);
228 auto secondary_dim = get_secondary_dim(row);
229 auto iIt = secondary_indices.begin();
230 auto iEnd = secondary_indices.end();
231 while (iIt != iEnd) {
232 const auto& chunk = my_chunk_array[offset];
234 Index_ secondary_end_pos = std::min(secondary_dim - secondary_start_pos, secondary_chunkdim) + secondary_start_pos;
235 chunk_indices_buffer.clear();
236 while (iIt != iEnd && *iIt < secondary_end_pos) {
237 chunk_indices_buffer.push_back(*iIt - secondary_start_pos);
241 if (!chunk_indices_buffer.empty()) {
242 if constexpr(sparse_) {
243 extract(chunk, chunk_indices_buffer, secondary_start_pos);
245 extract(chunk, chunk_indices_buffer);
249 secondary_start_pos = secondary_end_pos;
254 typedef typename Chunk_::Workspace ChunkWork;
255 typedef typename Chunk_::value_type ChunkValue;
256 typedef typename std::conditional<sparse_, typename SparseSlabFactory<ChunkValue, Index_>::Slab,
typename DenseSlabFactory<ChunkValue>::Slab>::type Slab;
257 typedef typename std::conditional<sparse_, SparseSingleWorkspace<ChunkValue, Index_>, DenseSingleWorkspace<ChunkValue> >::type SingleWorkspace;
271 std::pair<const Slab*, Index_> fetch_single(
274 Index_ secondary_block_start,
275 Index_ secondary_block_length,
276 ChunkWork& chunk_workspace,
277 SingleWorkspace& tmp_work,
280 Index_ primary_chunkdim = get_primary_chunkdim(row);
281 Index_ chunk_id = i / primary_chunkdim;
282 Index_ chunk_offset = i % primary_chunkdim;
284 if constexpr(sparse_) {
285 auto& final_num = *final_slab.number;
287 bool needs_value = !final_slab.values.empty();
288 bool needs_index = !final_slab.indices.empty();
290 extract_secondary_block(
291 row, chunk_id, secondary_block_start, secondary_block_length,
292 [&](
const Chunk_& chunk, Index_ from, Index_ len, Index_ secondary_start_pos) ->
void {
293 auto& tmp_values = tmp_work.get_values();
294 auto& tmp_indices = tmp_work.get_indices();
295 auto& tmp_number = tmp_work.get_number();
296 std::fill_n(tmp_number.begin(), primary_chunkdim, 0);
298 if constexpr(Chunk_::use_subset) {
299 chunk.extract(row, chunk_offset, 1, from, len, chunk_workspace, tmp_values, tmp_indices, tmp_number.data(), secondary_start_pos);
301 chunk.extract(row, from, len, chunk_workspace, tmp_values, tmp_indices, tmp_number.data(), secondary_start_pos);
304 auto count = tmp_number[chunk_offset];
306 std::copy_n(tmp_values[chunk_offset], count, final_slab.values[0] + final_num);
309 std::copy_n(tmp_indices[chunk_offset], count, final_slab.indices[0] + final_num);
316 auto final_slab_ptr = final_slab.data;
317 auto tmp_buffer_ptr = tmp_work.data();
319 extract_secondary_block(
320 row, chunk_id, secondary_block_start, secondary_block_length,
321 [&](
const Chunk_& chunk, Index_ from, Index_ len) ->
void {
322 if constexpr(Chunk_::use_subset) {
323 chunk.extract(row, chunk_offset, 1, from, len, chunk_workspace, tmp_buffer_ptr, len);
325 chunk.extract(row, from, len, chunk_workspace, tmp_buffer_ptr, len);
328 size_t tmp_offset =
static_cast<size_t>(len) *
static_cast<size_t>(chunk_offset);
329 std::copy_n(tmp_buffer_ptr + tmp_offset, len, final_slab_ptr);
330 final_slab_ptr += len;
335 return std::make_pair(&final_slab,
static_cast<Index_
>(0));
340 std::pair<const Slab*, Index_> fetch_single(
343 const std::vector<Index_>& secondary_indices,
344 std::vector<Index_>& chunk_indices_buffer,
345 ChunkWork& chunk_workspace,
346 SingleWorkspace& tmp_work,
349 Index_ primary_chunkdim = get_primary_chunkdim(row);
350 Index_ chunk_id = i / primary_chunkdim;
351 Index_ chunk_offset = i % primary_chunkdim;
353 if constexpr(sparse_) {
354 auto& final_num = *final_slab.number;
356 bool needs_value = !final_slab.values.empty();
357 bool needs_index = !final_slab.indices.empty();
359 extract_secondary_index(
360 row, chunk_id, secondary_indices, chunk_indices_buffer,
361 [&](
const Chunk_& chunk,
const std::vector<Index_>& chunk_indices, Index_ secondary_start_pos) ->
void {
362 auto& tmp_values = tmp_work.get_values();
363 auto& tmp_indices = tmp_work.get_indices();
364 auto& tmp_number = tmp_work.get_number();
365 std::fill_n(tmp_number.begin(), primary_chunkdim, 0);
367 if constexpr(Chunk_::use_subset) {
368 chunk.extract(row, chunk_offset, 1, chunk_indices, chunk_workspace, tmp_values, tmp_indices, tmp_number.data(), secondary_start_pos);
370 chunk.extract(row, chunk_indices, chunk_workspace, tmp_values, tmp_indices, tmp_number.data(), secondary_start_pos);
373 auto count = tmp_number[chunk_offset];
375 std::copy_n(tmp_values[chunk_offset], count, final_slab.values[0] + final_num);
378 std::copy_n(tmp_indices[chunk_offset], count, final_slab.indices[0] + final_num);
385 auto final_slab_ptr = final_slab.data;
386 auto tmp_buffer_ptr = tmp_work.data();
388 extract_secondary_index(
389 row, chunk_id, secondary_indices, chunk_indices_buffer,
390 [&](
const Chunk_& chunk,
const std::vector<Index_>& chunk_indices) ->
void {
391 size_t nidx = chunk_indices.size();
392 if constexpr(Chunk_::use_subset) {
393 chunk.extract(row, chunk_offset, 1, chunk_indices, chunk_workspace, tmp_buffer_ptr, nidx);
395 chunk.extract(row, chunk_indices, chunk_workspace, tmp_buffer_ptr, nidx);
398 size_t tmp_offset = nidx *
static_cast<size_t>(chunk_offset);
399 std::copy_n(tmp_buffer_ptr + tmp_offset, nidx, final_slab_ptr);
400 final_slab_ptr += nidx;
405 return std::make_pair(&final_slab,
static_cast<Index_
>(0));
415 Index_ secondary_block_start,
416 Index_ secondary_block_length,
418 ChunkWork& chunk_workspace)
420 if constexpr(sparse_) {
421 std::fill_n(slab.number, get_primary_chunkdim(row), 0);
423 extract_secondary_block(
424 row, chunk_id, secondary_block_start, secondary_block_length,
425 [&](
const Chunk_& chunk, Index_ from, Index_ len, Index_ secondary_start_pos) ->
void {
426 if constexpr(Chunk_::use_subset) {
427 chunk.extract(row, chunk_offset, chunk_length, from, len, chunk_workspace, slab.values, slab.indices, slab.number, secondary_start_pos);
429 chunk.extract(row, from, len, chunk_workspace, slab.values, slab.indices, slab.number, secondary_start_pos);
435 auto slab_ptr = slab.data;
436 size_t stride = secondary_block_length;
438 extract_secondary_block(
439 row, chunk_id, secondary_block_start, secondary_block_length,
440 [&](
const Chunk_& chunk, Index_ from, Index_ len) ->
void {
441 if constexpr(Chunk_::use_subset) {
442 chunk.extract(row, chunk_offset, chunk_length, from, len, chunk_workspace, slab_ptr, stride);
444 chunk.extract(row, from, len, chunk_workspace, slab_ptr, stride);
458 const std::vector<Index_>& secondary_indices,
459 std::vector<Index_>& chunk_indices_buffer,
461 ChunkWork& chunk_workspace)
463 if constexpr(sparse_) {
464 std::fill_n(slab.number, get_primary_chunkdim(row), 0);
466 extract_secondary_index(
467 row, chunk_id, secondary_indices, chunk_indices_buffer,
468 [&](
const Chunk_& chunk,
const std::vector<Index_>& chunk_indices, Index_ secondary_start_pos) ->
void {
469 if constexpr(Chunk_::use_subset) {
470 chunk.extract(row, chunk_offset, chunk_length, chunk_indices, chunk_workspace, slab.values, slab.indices, slab.number, secondary_start_pos);
472 chunk.extract(row, chunk_indices, chunk_workspace, slab.values, slab.indices, slab.number, secondary_start_pos);
478 auto slab_ptr = slab.data;
479 size_t stride = secondary_indices.size();
481 extract_secondary_index(
482 row, chunk_id, secondary_indices, chunk_indices_buffer,
483 [&](
const Chunk_& chunk,
const std::vector<Index_>& chunk_indices) ->
void {
484 if constexpr(Chunk_::use_subset) {
485 chunk.extract(row, chunk_offset, chunk_length, chunk_indices, chunk_workspace, slab_ptr, stride);
487 chunk.extract(row, chunk_indices, chunk_workspace, slab_ptr, stride);
489 slab_ptr += chunk_indices.size();
500 const std::vector<Index_>& primary_indices,
501 Index_ secondary_block_start,
502 Index_ secondary_block_length,
504 ChunkWork& chunk_workspace)
506 if constexpr(sparse_) {
507 std::fill_n(slab.number, get_primary_chunkdim(row), 0);
509 extract_secondary_block(
510 row, chunk_id, secondary_block_start, secondary_block_length,
511 [&](
const Chunk_& chunk, Index_ from, Index_ len, Index_ secondary_start_pos) ->
void {
512 if constexpr(Chunk_::use_subset) {
513 chunk.extract(row, primary_indices, from, len, chunk_workspace, slab.values, slab.indices, slab.number, secondary_start_pos);
515 chunk.extract(row, from, len, chunk_workspace, slab.values, slab.indices, slab.number, secondary_start_pos);
521 auto slab_ptr = slab.data;
522 size_t stride = secondary_block_length;
524 extract_secondary_block(
525 row, chunk_id, secondary_block_start, secondary_block_length,
526 [&](
const Chunk_& chunk, Index_ from, Index_ len) ->
void {
527 if constexpr(Chunk_::use_subset) {
528 chunk.extract(row, primary_indices, from, len, chunk_workspace, slab_ptr, stride);
530 chunk.extract(row, from, len, chunk_workspace, slab_ptr, stride);
542 const std::vector<Index_>& primary_indices,
543 const std::vector<Index_>& secondary_indices,
544 std::vector<Index_>& chunk_indices_buffer,
546 ChunkWork& chunk_workspace)
548 if constexpr(sparse_) {
549 std::fill_n(slab.number, get_primary_chunkdim(row), 0);
551 extract_secondary_index(
552 row, chunk_id, secondary_indices, chunk_indices_buffer,
553 [&](
const Chunk_& chunk,
const std::vector<Index_>& chunk_indices, Index_ secondary_start_pos) ->
void {
554 if constexpr(Chunk_::use_subset) {
555 chunk.extract(row, primary_indices, chunk_indices, chunk_workspace, slab.values, slab.indices, slab.number, secondary_start_pos);
557 chunk.extract(row, chunk_indices, chunk_workspace, slab.values, slab.indices, slab.number, secondary_start_pos);
563 auto slab_ptr = slab.data;
564 size_t stride = secondary_indices.size();
566 extract_secondary_index(
567 row, chunk_id, secondary_indices, chunk_indices_buffer,
568 [&](
const Chunk_& chunk,
const std::vector<Index_>& chunk_indices) ->
void {
569 if constexpr(Chunk_::use_subset) {
570 chunk.extract(row, primary_indices, chunk_indices, chunk_workspace, slab_ptr, stride);
572 chunk.extract(row, chunk_indices, chunk_workspace, slab_ptr, stride);
574 slab_ptr += chunk_indices.size();
582 template<
class Cache_,
class Factory_>
583 std::pair<const Slab*, Index_> fetch_myopic(
588 ChunkWork& chunk_workspace,
592 Index_ primary_chunkdim = get_primary_chunkdim(row);
593 Index_ chunk_id = i / primary_chunkdim;
594 Index_ chunk_offset = i % primary_chunkdim;
595 auto& out = cache.find(
598 return factory.create();
600 [&](Index_ id, Slab& slab) ->
void {
601 fetch_block(row,
id, 0, get_primary_chunkdim(row,
id), block_start, block_length, slab, chunk_workspace);
604 return std::make_pair(&out, chunk_offset);
607 template<
class Cache_,
class Factory_>
608 std::pair<const Slab*, Index_> fetch_myopic(
611 const std::vector<Index_>& indices,
612 std::vector<Index_>& tmp_indices,
613 ChunkWork& chunk_workspace,
617 Index_ primary_chunkdim = get_primary_chunkdim(row);
618 Index_ chunk_id = i / primary_chunkdim;
619 Index_ chunk_offset = i % primary_chunkdim;
620 auto& out = cache.find(
623 return factory.create();
625 [&](Index_ id, Slab& slab) ->
void {
626 fetch_block(row,
id, 0, get_primary_chunkdim(row,
id), indices, tmp_indices, slab, chunk_workspace);
629 return std::make_pair(&out, chunk_offset);
633 template<
class Cache_,
class Factory_,
typename PopulateBlock_,
typename PopulateIndex_>
634 std::pair<const Slab*, Index_> fetch_oracular_core(
638 PopulateBlock_ populate_block,
639 PopulateIndex_ populate_index)
641 Index_ primary_chunkdim = get_primary_chunkdim(row);
642 if constexpr(Chunk_::use_subset) {
644 [&](Index_ i) -> std::pair<Index_, Index_> {
645 return std::pair<Index_, Index_>(i / primary_chunkdim, i % primary_chunkdim);
648 return factory.create();
650 [&](std::vector<std::tuple<Index_, Slab*, const OracularSubsettedSlabCacheSelectionDetails<Index_>*> >& in_need) ->
void {
651 for (
const auto& p : in_need) {
652 auto id = std::get<0>(p);
653 auto ptr = std::get<1>(p);
654 auto sub = std::get<2>(p);
655 switch (sub->selection) {
656 case OracularSubsettedSlabCacheSelectionType::FULL:
657 populate_block(
id, 0, get_primary_chunkdim(row,
id), *ptr);
659 case OracularSubsettedSlabCacheSelectionType::BLOCK:
660 populate_block(
id, sub->block_start, sub->block_length, *ptr);
662 case OracularSubsettedSlabCacheSelectionType::INDEX:
663 populate_index(
id, sub->indices, *ptr);
672 [&](Index_ i) -> std::pair<Index_, Index_> {
673 return std::pair<Index_, Index_>(i / primary_chunkdim, i % primary_chunkdim);
676 return factory.create();
678 [&](std::vector<std::pair<Index_, Slab*> >& to_populate) ->
void {
679 for (
auto& p : to_populate) {
680 populate_block(p.first, 0, get_primary_chunkdim(row, p.first), *(p.second));
688 template<
class Cache_,
class Factory_>
689 std::pair<const Slab*, Index_> fetch_oracular(
693 ChunkWork& chunk_workspace,
697 return fetch_oracular_core(
701 [&](Index_ pid, Index_ pstart, Index_ plen, Slab& slab) ->
void {
702 fetch_block(row, pid, pstart, plen, block_start, block_length, slab, chunk_workspace);
704 [&](Index_ pid,
const std::vector<Index_>& pindices, Slab& slab) ->
void {
705 fetch_index(row, pid, pindices, block_start, block_length, slab, chunk_workspace);
710 template<
class Cache_,
class Factory_>
711 std::pair<const Slab*, Index_> fetch_oracular(
713 const std::vector<Index_>& indices,
714 std::vector<Index_>& chunk_indices_buffer,
715 ChunkWork& chunk_workspace,
719 return fetch_oracular_core(
723 [&](Index_ pid, Index_ pstart, Index_ plen, Slab& slab) ->
void {
724 fetch_block(row, pid, pstart, plen, indices, chunk_indices_buffer, slab, chunk_workspace);
726 [&](Index_ pid,
const std::vector<Index_>& pindices, Slab& slab) ->
void {
727 fetch_index(row, pid, pindices, indices, chunk_indices_buffer, slab, chunk_workspace);
Create an oracle-aware cache for slabs.
Create an oracle-aware cache with subsets.
Factory for sparse slabs.
Methods to handle chunked tatami matrices.
Definition ChunkDimensionStats.hpp:4
Index_ get_chunk_length(const ChunkDimensionStats< Index_ > &stats, Index_ i)
Definition ChunkDimensionStats.hpp:85