#ifndef TATAMI_CHUNKED_CUSTOM_CHUNK_COORDINATOR_HPP
#define TATAMI_CHUNKED_CUSTOM_CHUNK_COORDINATOR_HPP

#include <algorithm>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "ChunkDimensionStats.hpp"
#include "OracularSubsettedSlabCache.hpp"
#include "sanisizer/sanisizer.hpp"

namespace tatami_chunked {

namespace CustomChunkedMatrix_internal {
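
// Internal machinery for the CustomDenseChunkedMatrix and
// CustomSparseChunkedMatrix classes. The coordinator below maps requests in
// matrix coordinates onto the underlying grid of chunks; for example, a
// 100 x 60 matrix with 10 x 20 chunks has a 10 x 3 chunk grid, so matrix row
// 57 lives at offset 7 of the chunk row strip with ID 5, while matrix columns
// [25, 48) touch chunk columns 1 and 2. All of the arithmetic in this file is
// of this flavor.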
template<typename ChunkValue_>
using DenseSingleWorkspace = std::vector<ChunkValue_>;
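
// Temporary workspaces used when fetching a single target element. The dense
// case above only needs a flat buffer of chunk values. The sparse case needs
// a value buffer, an index buffer and a count for every position along the
// target chunk dimension, as the chunk extractor reports its results per
// target position; both pools are allocated once up front, with 'my_values'
// and 'my_indices' holding one pointer into the pools per target position.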
template<typename ChunkValue_, typename Index_>
class SparseSingleWorkspace {
public:
    SparseSingleWorkspace(Index_ target_chunkdim, Index_ non_target_chunkdim, bool needs_value, bool needs_index) : my_number(target_chunkdim) {
        if (needs_value) {
            my_value_pool.resize(sanisizer::product<decltype(my_value_pool.size())>(target_chunkdim, non_target_chunkdim));
            my_values.reserve(target_chunkdim);
            auto vptr = my_value_pool.data();
            for (Index_ p = 0; p < target_chunkdim; ++p, vptr += non_target_chunkdim) {
                my_values.push_back(vptr);
            }
        }
        if (needs_index) {
            my_index_pool.resize(sanisizer::product<decltype(my_index_pool.size())>(target_chunkdim, non_target_chunkdim));
            my_indices.reserve(target_chunkdim);
            auto iptr = my_index_pool.data();
            for (Index_ p = 0; p < target_chunkdim; ++p, iptr += non_target_chunkdim) {
                my_indices.push_back(iptr);
            }
        }
    }
    // Copying is disabled as 'my_values' and 'my_indices' would still point
    // into the source object's pools; moving is fine because vector moves
    // preserve the heap allocations that those pointers refer to.
    SparseSingleWorkspace(const SparseSingleWorkspace&) = delete;
    SparseSingleWorkspace& operator=(const SparseSingleWorkspace&) = delete;
    SparseSingleWorkspace(SparseSingleWorkspace&&) = default;
    SparseSingleWorkspace& operator=(SparseSingleWorkspace&&) = default;
private:
    std::vector<ChunkValue_> my_value_pool;
    std::vector<Index_> my_index_pool;
    std::vector<ChunkValue_*> my_values;
    std::vector<Index_*> my_indices;
    std::vector<Index_> my_number;

public:
    std::vector<ChunkValue_*>& get_values() {
        return my_values;
    }

    std::vector<Index_*>& get_indices() {
        return my_indices;
    }

    std::vector<Index_>& get_number() {
        return my_number;
    }
};
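
// Coordinates extraction from a grid of custom chunks, where 'sparse_'
// selects between sparse and dense slab types, 'ChunkValue_' is the stored
// value type and 'Index_' is the row/column index type. Each dimension is
// described by a ChunkDimensionStats<Index_>, from which we use the overall
// extent ('dimension_extent'), the common chunk length ('chunk_length') and
// the number of chunks ('num_chunks') along that dimension.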
template<bool sparse_, class ChunkValue_, typename Index_>
class ChunkCoordinator {
public:
    ChunkCoordinator(ChunkDimensionStats<Index_> row_stats, ChunkDimensionStats<Index_> col_stats) :
        my_row_stats(std::move(row_stats)),
        my_col_stats(std::move(col_stats))
    {}

private:
    ChunkDimensionStats<Index_> my_row_stats;
    ChunkDimensionStats<Index_> my_col_stats;
public:
    // The number of chunks along each row is the number of chunks along the
    // column dimension, and vice versa.
    Index_ get_num_chunks_per_row() const {
        return my_col_stats.num_chunks;
    }

    Index_ get_num_chunks_per_column() const {
        return my_row_stats.num_chunks;
    }

    Index_ get_nrow() const {
        return my_row_stats.dimension_extent;
    }

    Index_ get_ncol() const {
        return my_col_stats.dimension_extent;
    }

    Index_ get_chunk_nrow() const {
        return my_row_stats.chunk_length;
    }

    Index_ get_chunk_ncol() const {
        return my_col_stats.chunk_length;
    }
    // Dimension extents and chunk lengths, parametrized by the dimension
    // being iterated over (the "target") versus the other (the "non-target").
    Index_ get_non_target_dim(bool row) const {
        if (row) {
            return my_col_stats.dimension_extent;
        } else {
            return my_row_stats.dimension_extent;
        }
    }

    Index_ get_target_chunkdim(bool row) const {
        if (row) {
            return my_row_stats.chunk_length;
        } else {
            return my_col_stats.chunk_length;
        }
    }

    Index_ get_non_target_chunkdim(bool row) const {
        if (row) {
            return my_col_stats.chunk_length;
        } else {
            return my_row_stats.chunk_length;
        }
    }

    // Overload that accounts for the truncated chunk at the end of each
    // dimension, via the get_chunk_length() helper in ChunkDimensionStats.hpp.
    Index_ get_target_chunkdim(bool row, Index_ chunk_id) const {
        return get_chunk_length(row ? my_row_stats : my_col_stats, chunk_id);
    }
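
    // Decomposes a contiguous block on the non-target dimension into a series
    // of per-chunk intervals, invoking 'extract' on each overlapping
    // non-target chunk. For example, with a non-target chunk length of 10,
    // the block [25, 48) is decomposed into [5, 10) of chunk 2, [0, 10) of
    // chunk 3 and [0, 8) of chunk 4. In the sparse case, 'extract' also
    // receives the offset of the current chunk so that chunk-local indices
    // can be shifted back into the matrix's coordinate system.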
    template<class ExtractFunction_>
    void extract_non_target_block(
        bool row,
        Index_ target_chunk_id,
        Index_ non_target_block_start,
        Index_ non_target_block_length,
        ExtractFunction_ extract)
    const {
        auto non_target_chunkdim = get_non_target_chunkdim(row);
        Index_ non_target_start_chunk_index = non_target_block_start / non_target_chunkdim;
        Index_ non_target_start_pos = non_target_start_chunk_index * non_target_chunkdim;
        Index_ non_target_block_end = non_target_block_start + non_target_block_length;
        Index_ non_target_end_chunk_index = non_target_block_end / non_target_chunkdim + (non_target_block_end % non_target_chunkdim > 0); // i.e., ceiling of the integer division.

        for (Index_ non_target_chunk_id = non_target_start_chunk_index; non_target_chunk_id < non_target_end_chunk_index; ++non_target_chunk_id) {
            Index_ from = (non_target_chunk_id == non_target_start_chunk_index ? non_target_block_start - non_target_start_pos : 0);
            Index_ to = (non_target_chunk_id + 1 == non_target_end_chunk_index ? non_target_block_end - non_target_start_pos : non_target_chunkdim);
            Index_ len = to - from;

            auto row_id = (row ? target_chunk_id : non_target_chunk_id);
            auto col_id = (row ? non_target_chunk_id : target_chunk_id);

            if constexpr(sparse_) {
                extract(row_id, col_id, from, len, non_target_start_pos);
            } else {
                extract(row_id, col_id, from, len);
            }

            // 'to' is either the full chunk length or the distance to the
            // block end, so this advances 'non_target_start_pos' to the start
            // of the next chunk without overshooting the block.
            non_target_start_pos += to;
        }
    }
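
    // As above, but for a sorted set of indices on the non-target dimension.
    // For example, with a non-target chunk length of 10, the indices
    // {3, 5, 17, 22, 40} are split into {3, 5} for chunk 0, {7} for chunk 1,
    // {2} for chunk 2 and {0} for chunk 4; each group is passed to 'extract'
    // via 'chunk_indices_buffer'.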
    template<class ExtractFunction_>
    void extract_non_target_index(
        bool row,
        Index_ target_chunk_id,
        const std::vector<Index_>& non_target_indices,
        std::vector<Index_>& chunk_indices_buffer,
        ExtractFunction_ extract)
    const {
        auto non_target_chunkdim = get_non_target_chunkdim(row);
        auto non_target_dim = get_non_target_dim(row);
        auto iIt = non_target_indices.begin();
        auto iEnd = non_target_indices.end();

        while (iIt != iEnd) {
            Index_ non_target_chunk_id = *iIt / non_target_chunkdim;
            Index_ non_target_start_pos = non_target_chunk_id * non_target_chunkdim;
            Index_ non_target_end_pos = std::min(non_target_dim - non_target_start_pos, non_target_chunkdim) + non_target_start_pos;

            // Collect all indices falling into the current chunk, converted
            // to chunk-local coordinates.
            chunk_indices_buffer.clear();
            do {
                chunk_indices_buffer.push_back(*iIt - non_target_start_pos);
                ++iIt;
            } while (iIt != iEnd && *iIt < non_target_end_pos);

            auto row_id = (row ? target_chunk_id : non_target_chunk_id);
            auto col_id = (row ? non_target_chunk_id : target_chunk_id);
            if constexpr(sparse_) {
                extract(row_id, col_id, chunk_indices_buffer, non_target_start_pos);
            } else {
                extract(row_id, col_id, chunk_indices_buffer);
            }
        }
    }
public:
    typedef typename std::conditional<sparse_, typename SparseSlabFactory<ChunkValue_, Index_>::Slab, typename DenseSlabFactory<ChunkValue_>::Slab>::type Slab;
    typedef typename std::conditional<sparse_, SparseSingleWorkspace<ChunkValue_, Index_>, DenseSingleWorkspace<ChunkValue_> >::type SingleWorkspace;
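
    // Fetches a single element along the target dimension without any
    // caching: each overlapping non-target chunk is asked for exactly one
    // target position, staged in 'tmp_work' and then copied into
    // 'final_slab'. The calls below assume that 'chunk_workspace' implements
    // extract() in the Custom*ChunkedMatrixWorkspace style, i.e., taking the
    // chunk's row/column IDs, the row flag, the target start/length, the
    // non-target block (or indices), the output buffer(s) and a stride (or,
    // for sparse data, an index shift). The second member of the returned
    // pair is always zero as the slab holds a single target element.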
    template<class ChunkWorkspace_>
    std::pair<const Slab*, Index_> fetch_single(
        bool row,
        Index_ i,
        Index_ non_target_block_start,
        Index_ non_target_block_length,
        ChunkWorkspace_& chunk_workspace,
        SingleWorkspace& tmp_work,
        Slab& final_slab)
    const {
        Index_ target_chunkdim = get_target_chunkdim(row);
        Index_ target_chunk_id = i / target_chunkdim;
        Index_ target_chunk_offset = i % target_chunkdim;

        if constexpr(sparse_) {
            auto& final_num = *final_slab.number;
            final_num = 0;
            bool needs_value = !final_slab.values.empty();
            bool needs_index = !final_slab.indices.empty();

            extract_non_target_block(
                row,
                target_chunk_id,
                non_target_block_start,
                non_target_block_length,
                [&](Index_ row_id, Index_ column_id, Index_ from, Index_ len, Index_ non_target_start_pos) -> void {
                    auto& tmp_values = tmp_work.get_values();
                    auto& tmp_indices = tmp_work.get_indices();
                    auto& tmp_number = tmp_work.get_number();

                    tmp_number[target_chunk_offset] = 0;
                    chunk_workspace.extract(
                        row_id, column_id, row,
                        target_chunk_offset, static_cast<Index_>(1),
                        from, len,
                        tmp_values, tmp_indices, tmp_number.data(), non_target_start_pos
                    );

                    auto count = tmp_number[target_chunk_offset];
                    if (needs_value) {
                        std::copy_n(tmp_values[target_chunk_offset], count, final_slab.values[0] + final_num);
                    }
                    if (needs_index) {
                        std::copy_n(tmp_indices[target_chunk_offset], count, final_slab.indices[0] + final_num);
                    }
                    final_num += count;
                }
            );

        } else {
            auto final_slab_ptr = final_slab.data;
            auto tmp_buffer_ptr = tmp_work.data();
            typedef decltype(tmp_work.size()) Size;

            extract_non_target_block(
                row,
                target_chunk_id,
                non_target_block_start,
                non_target_block_length,
                [&](Index_ row_id, Index_ column_id, Index_ from, Index_ len) -> void {
                    chunk_workspace.extract(
                        row_id, column_id, row,
                        target_chunk_offset, static_cast<Index_>(1),
                        from, len,
                        tmp_buffer_ptr, len
                    );

                    // The chunk extractor deposits results at the absolute
                    // target position within the chunk, so pull out the strip
                    // at 'target_chunk_offset * len'.
                    Size tmp_offset = sanisizer::product_unsafe<Size>(len, target_chunk_offset);
                    std::copy_n(tmp_buffer_ptr + tmp_offset, len, final_slab_ptr);
                    final_slab_ptr += len;
                }
            );
        }

        return std::make_pair(&final_slab, static_cast<Index_>(0));
    }
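
    // As above, but fetching the single target element for an indexed subset
    // of the non-target dimension; 'chunk_indices_buffer' is scratch space
    // for the chunk-local translation of 'non_target_indices'.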
    template<class ChunkWorkspace_>
    std::pair<const Slab*, Index_> fetch_single(
        bool row,
        Index_ i,
        const std::vector<Index_>& non_target_indices,
        std::vector<Index_>& chunk_indices_buffer,
        ChunkWorkspace_& chunk_workspace,
        SingleWorkspace& tmp_work,
        Slab& final_slab)
    const {
        Index_ target_chunkdim = get_target_chunkdim(row);
        Index_ target_chunk_id = i / target_chunkdim;
        Index_ target_chunk_offset = i % target_chunkdim;

        if constexpr(sparse_) {
            auto& final_num = *final_slab.number;
            final_num = 0;
            bool needs_value = !final_slab.values.empty();
            bool needs_index = !final_slab.indices.empty();

            extract_non_target_index(
                row,
                target_chunk_id,
                non_target_indices,
                chunk_indices_buffer,
                [&](Index_ row_id, Index_ column_id, const std::vector<Index_>& chunk_indices, Index_ non_target_start_pos) -> void {
                    auto& tmp_values = tmp_work.get_values();
                    auto& tmp_indices = tmp_work.get_indices();
                    auto& tmp_number = tmp_work.get_number();

                    tmp_number[target_chunk_offset] = 0;
                    chunk_workspace.extract(
                        row_id, column_id, row,
                        target_chunk_offset, static_cast<Index_>(1),
                        chunk_indices,
                        tmp_values, tmp_indices, tmp_number.data(), non_target_start_pos
                    );

                    auto count = tmp_number[target_chunk_offset];
                    if (needs_value) {
                        std::copy_n(tmp_values[target_chunk_offset], count, final_slab.values[0] + final_num);
                    }
                    if (needs_index) {
                        std::copy_n(tmp_indices[target_chunk_offset], count, final_slab.indices[0] + final_num);
                    }
                    final_num += count;
                }
            );

        } else {
            auto final_slab_ptr = final_slab.data;
            auto tmp_buffer_ptr = tmp_work.data();
            typedef decltype(tmp_work.size()) Size;

            extract_non_target_index(
                row,
                target_chunk_id,
                non_target_indices,
                chunk_indices_buffer,
                [&](Index_ row_id, Index_ column_id, const std::vector<Index_>& chunk_indices) -> void {
                    auto nidx = chunk_indices.size();
                    chunk_workspace.extract(
                        row_id, column_id, row,
                        target_chunk_offset, static_cast<Index_>(1),
                        chunk_indices,
                        tmp_buffer_ptr, static_cast<Index_>(nidx)
                    );

                    Size tmp_offset = static_cast<Size>(nidx) * static_cast<Size>(target_chunk_offset);
                    std::copy_n(tmp_buffer_ptr + tmp_offset, nidx, final_slab_ptr);
                    final_slab_ptr += nidx;
                }
            );
        }

        return std::make_pair(&final_slab, static_cast<Index_>(0));
    }
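
    // Fetches an entire strip of chunks along the non-target dimension, i.e.,
    // 'target_chunk_length' consecutive target positions starting at
    // 'target_chunk_offset' within the chunk strip 'target_chunk_id'. Dense
    // slabs are filled contiguously with a stride equal to the full
    // non-target extent of the request; sparse slabs accumulate per-position
    // counts in 'slab.number' across successive non-target chunks.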
    template<class ChunkWorkspace_>
    void fetch_block(
        bool row,
        Index_ target_chunk_id,
        Index_ target_chunk_offset,
        Index_ target_chunk_length,
        Index_ non_target_block_start,
        Index_ non_target_block_length,
        Slab& slab,
        ChunkWorkspace_& chunk_workspace)
    const {
        if constexpr(sparse_) {
            std::fill_n(slab.number, get_target_chunkdim(row), 0);

            extract_non_target_block(
                row,
                target_chunk_id,
                non_target_block_start,
                non_target_block_length,
                [&](Index_ row_id, Index_ column_id, Index_ from, Index_ len, Index_ non_target_start_pos) -> void {
                    chunk_workspace.extract(
                        row_id, column_id, row,
                        target_chunk_offset, target_chunk_length,
                        from, len,
                        slab.values, slab.indices, slab.number, non_target_start_pos
                    );
                }
            );

        } else {
            auto slab_ptr = slab.data;

            extract_non_target_block(
                row,
                target_chunk_id,
                non_target_block_start,
                non_target_block_length,
                [&](Index_ row_id, Index_ column_id, Index_ from, Index_ len) -> void {
                    chunk_workspace.extract(
                        row_id, column_id, row,
                        target_chunk_offset, target_chunk_length,
                        from, len,
                        slab_ptr, non_target_block_length
                    );
                    slab_ptr += len;
                }
            );
        }
    }
    template<class ChunkWorkspace_>
    void fetch_block(
        bool row,
        Index_ target_chunk_id,
        Index_ target_chunk_offset,
        Index_ target_chunk_length,
        const std::vector<Index_>& non_target_indices,
        std::vector<Index_>& chunk_indices_buffer,
        Slab& slab,
        ChunkWorkspace_& chunk_workspace)
    const {
        if constexpr(sparse_) {
            std::fill_n(slab.number, get_target_chunkdim(row), 0);

            extract_non_target_index(
                row,
                target_chunk_id,
                non_target_indices,
                chunk_indices_buffer,
                [&](Index_ row_id, Index_ column_id, const std::vector<Index_>& chunk_indices, Index_ non_target_start_pos) -> void {
                    chunk_workspace.extract(
                        row_id, column_id, row,
                        target_chunk_offset, target_chunk_length,
                        chunk_indices,
                        slab.values, slab.indices, slab.number, non_target_start_pos
                    );
                }
            );

        } else {
            auto slab_ptr = slab.data;
            Index_ stride = non_target_indices.size();

            extract_non_target_index(
                row,
                target_chunk_id,
                non_target_indices,
                chunk_indices_buffer,
                [&](Index_ row_id, Index_ column_id, const std::vector<Index_>& chunk_indices) -> void {
                    chunk_workspace.extract(
                        row_id, column_id, row,
                        target_chunk_offset, target_chunk_length,
                        chunk_indices,
                        slab_ptr, stride
                    );
                    slab_ptr += chunk_indices.size();
                }
            );
        }
    }
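
    // Same as fetch_block(), but for an indexed subset along the target
    // dimension instead of a contiguous [offset, offset + length) block.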
    template<class ChunkWorkspace_>
    void fetch_index(
        bool row,
        Index_ target_chunk_id,
        const std::vector<Index_>& target_indices,
        Index_ non_target_block_start,
        Index_ non_target_block_length,
        Slab& slab,
        ChunkWorkspace_& chunk_workspace)
    const {
        if constexpr(sparse_) {
            std::fill_n(slab.number, get_target_chunkdim(row), 0);
            extract_non_target_block(
                row,
                target_chunk_id,
                non_target_block_start,
                non_target_block_length,
                [&](Index_ row_id, Index_ column_id, Index_ from, Index_ len, Index_ non_target_start_pos) -> void {
                    chunk_workspace.extract(
                        row_id, column_id, row,
                        target_indices,
                        from, len,
                        slab.values, slab.indices, slab.number, non_target_start_pos
                    );
                }
            );

        } else {
            auto slab_ptr = slab.data;
            extract_non_target_block(
                row,
                target_chunk_id,
                non_target_block_start,
                non_target_block_length,
                [&](Index_ row_id, Index_ column_id, Index_ from, Index_ len) -> void {
                    chunk_workspace.extract(
                        row_id, column_id, row,
                        target_indices,
                        from, len,
                        slab_ptr, non_target_block_length
                    );
                    slab_ptr += len;
                }
            );
        }
    }
    template<class ChunkWorkspace_>
    void fetch_index(
        bool row,
        Index_ target_chunk_id,
        const std::vector<Index_>& target_indices,
        const std::vector<Index_>& non_target_indices,
        std::vector<Index_>& chunk_indices_buffer,
        Slab& slab,
        ChunkWorkspace_& chunk_workspace)
    const {
        if constexpr(sparse_) {
            std::fill_n(slab.number, get_target_chunkdim(row), 0);
            extract_non_target_index(
                row,
                target_chunk_id,
                non_target_indices,
                chunk_indices_buffer,
                [&](Index_ row_id, Index_ column_id, const std::vector<Index_>& chunk_indices, Index_ non_target_start_pos) -> void {
                    chunk_workspace.extract(
                        row_id, column_id, row,
                        target_indices,
                        chunk_indices,
                        slab.values, slab.indices, slab.number, non_target_start_pos
                    );
                }
            );

        } else {
            auto slab_ptr = slab.data;
            Index_ stride = non_target_indices.size();
            extract_non_target_index(
                row,
                target_chunk_id,
                non_target_indices,
                chunk_indices_buffer,
                [&](Index_ row_id, Index_ column_id, const std::vector<Index_>& chunk_indices) -> void {
                    chunk_workspace.extract(
                        row_id, column_id, row,
                        target_indices,
                        chunk_indices,
                        slab_ptr, stride
                    );
                    slab_ptr += chunk_indices.size();
                }
            );
        }
    }
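
    // Cache-aware fetch methods. The "myopic" variants service one target
    // element at a time through a cache exposing find(id, create, populate),
    // e.g., an LRU cache keyed by chunk strip ID; the "oracular" variants
    // delegate the chunk ID/offset decomposition to an oracle-driven cache
    // exposing next(identify, create, populate), which may populate several
    // slabs per call. A minimal sketch of the myopic flow, with hypothetical
    // 'cache', 'factory' and 'work' objects satisfying the interfaces used
    // below:
    //
    //     auto res = coordinator.fetch_myopic(
    //         /* row = */ true, /* i = */ 57,
    //         /* block_start = */ 0, /* block_length = */ coordinator.get_ncol(),
    //         work, cache, factory);
    //     // With a chunk height of 10, res.first is the populated slab for
    //     // chunk strip 5 and res.second == 7 is the offset of row 57 in it.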
    template<class ChunkWorkspace_, class Cache_, class Factory_>
    std::pair<const Slab*, Index_> fetch_myopic(
        bool row,
        Index_ i,
        Index_ block_start,
        Index_ block_length,
        ChunkWorkspace_& chunk_workspace,
        Cache_& cache,
        Factory_& factory)
    const {
        Index_ target_chunkdim = get_target_chunkdim(row);
        Index_ target_chunk_id = i / target_chunkdim;
        Index_ target_chunk_offset = i % target_chunkdim;
        auto& out = cache.find(
            target_chunk_id,
            /* create = */ [&]() -> Slab {
                return factory.create();
            },
            /* populate = */ [&](Index_ id, Slab& slab) -> void {
                fetch_block(row, id, 0, get_target_chunkdim(row, id), block_start, block_length, slab, chunk_workspace);
            }
        );
        return std::make_pair(&out, target_chunk_offset);
    }
    template<class ChunkWorkspace_, class Cache_, class Factory_>
    std::pair<const Slab*, Index_> fetch_myopic(
        bool row,
        Index_ i,
        const std::vector<Index_>& indices,
        std::vector<Index_>& tmp_indices,
        ChunkWorkspace_& chunk_workspace,
        Cache_& cache,
        Factory_& factory)
    const {
        Index_ target_chunkdim = get_target_chunkdim(row);
        Index_ target_chunk_id = i / target_chunkdim;
        Index_ target_chunk_offset = i % target_chunkdim;
        auto& out = cache.find(
            target_chunk_id,
            /* create = */ [&]() -> Slab {
                return factory.create();
            },
            /* populate = */ [&](Index_ id, Slab& slab) -> void {
                fetch_block(row, id, 0, get_target_chunkdim(row, id), indices, tmp_indices, slab, chunk_workspace);
            }
        );
        return std::make_pair(&out, target_chunk_offset);
    }
    template<class ChunkWorkspace_, class Cache_, class Factory_>
    std::pair<const Slab*, Index_> fetch_oracular(
        bool row,
        Index_ block_start,
        Index_ block_length,
        ChunkWorkspace_& chunk_workspace,
        Cache_& cache,
        Factory_& factory)
    const {
        Index_ target_chunkdim = get_target_chunkdim(row);
        return cache.next(
            /* identify = */ [&](Index_ i) -> std::pair<Index_, Index_> {
                return std::pair<Index_, Index_>(i / target_chunkdim, i % target_chunkdim);
            },
            /* create = */ [&]() -> Slab {
                return factory.create();
            },
            /* populate = */ [&](std::vector<std::pair<Index_, Slab*> >& to_populate) -> void {
                for (auto& p : to_populate) {
                    fetch_block(row, p.first, 0, get_target_chunkdim(row, p.first), block_start, block_length, *(p.second), chunk_workspace);
                }
            }
        );
    }
    template<class ChunkWorkspace_, class Cache_, class Factory_>
    std::pair<const Slab*, Index_> fetch_oracular(
        bool row,
        const std::vector<Index_>& indices,
        std::vector<Index_>& chunk_indices_buffer,
        ChunkWorkspace_& chunk_workspace,
        Cache_& cache,
        Factory_& factory)
    const {
        Index_ target_chunkdim = get_target_chunkdim(row);
        return cache.next(
            /* identify = */ [&](Index_ i) -> std::pair<Index_, Index_> {
                return std::pair<Index_, Index_>(i / target_chunkdim, i % target_chunkdim);
            },
            /* create = */ [&]() -> Slab {
                return factory.create();
            },
            /* populate = */ [&](std::vector<std::pair<Index_, Slab*> >& to_populate) -> void {
                for (auto& p : to_populate) {
                    fetch_block(row, p.first, 0, get_target_chunkdim(row, p.first), indices, chunk_indices_buffer, *(p.second), chunk_workspace);
                }
            }
        );
    }
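
    // As above, but with an oracle-driven cache that also reports which
    // subset of each slab will actually be used (FULL, BLOCK or INDEX), so
    // that only the required positions of each chunk strip are extracted.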
    template<class ChunkWorkspace_, class Cache_, class Factory_>
    std::pair<const Slab*, Index_> fetch_oracular_subsetted(
        bool row,
        Index_ block_start,
        Index_ block_length,
        ChunkWorkspace_& chunk_workspace,
        Cache_& cache,
        Factory_& factory)
    const {
        Index_ target_chunkdim = get_target_chunkdim(row);
        return cache.next(
            /* identify = */ [&](Index_ i) -> std::pair<Index_, Index_> {
                return std::pair<Index_, Index_>(i / target_chunkdim, i % target_chunkdim);
            },
            /* create = */ [&]() -> Slab {
                return factory.create();
            },
            /* populate = */ [&](std::vector<std::tuple<Index_, Slab*, const OracularSubsettedSlabCacheSelectionDetails<Index_>*> >& in_need) -> void {
                for (const auto& p : in_need) {
                    auto id = std::get<0>(p);
                    auto ptr = std::get<1>(p);
                    auto sub = std::get<2>(p);
                    switch (sub->selection) {
                        case OracularSubsettedSlabCacheSelectionType::FULL:
                            fetch_block(row, id, 0, get_target_chunkdim(row, id), block_start, block_length, *ptr, chunk_workspace);
                            break;
                        case OracularSubsettedSlabCacheSelectionType::BLOCK:
                            fetch_block(row, id, sub->block_start, sub->block_length, block_start, block_length, *ptr, chunk_workspace);
                            break;
                        case OracularSubsettedSlabCacheSelectionType::INDEX:
                            fetch_index(row, id, sub->indices, block_start, block_length, *ptr, chunk_workspace);
                            break;
                    }
                }
            }
        );
    }
    template<class ChunkWorkspace_, class Cache_, class Factory_>
    std::pair<const Slab*, Index_> fetch_oracular_subsetted(
        bool row,
        const std::vector<Index_>& indices,
        std::vector<Index_>& chunk_indices_buffer,
        ChunkWorkspace_& chunk_workspace,
        Cache_& cache,
        Factory_& factory)
    const {
        Index_ target_chunkdim = get_target_chunkdim(row);
        return cache.next(
            /* identify = */ [&](Index_ i) -> std::pair<Index_, Index_> {
                return std::pair<Index_, Index_>(i / target_chunkdim, i % target_chunkdim);
            },
            /* create = */ [&]() -> Slab {
                return factory.create();
            },
            /* populate = */ [&](std::vector<std::tuple<Index_, Slab*, const OracularSubsettedSlabCacheSelectionDetails<Index_>*> >& in_need) -> void {
                for (const auto& p : in_need) {
                    auto id = std::get<0>(p);
                    auto ptr = std::get<1>(p);
                    auto sub = std::get<2>(p);
                    switch (sub->selection) {
                        case OracularSubsettedSlabCacheSelectionType::FULL:
                            fetch_block(row, id, 0, get_target_chunkdim(row, id), indices, chunk_indices_buffer, *ptr, chunk_workspace);
                            break;
                        case OracularSubsettedSlabCacheSelectionType::BLOCK:
                            fetch_block(row, id, sub->block_start, sub->block_length, indices, chunk_indices_buffer, *ptr, chunk_workspace);
                            break;
                        case OracularSubsettedSlabCacheSelectionType::INDEX:
                            fetch_index(row, id, sub->indices, indices, chunk_indices_buffer, *ptr, chunk_workspace);
                            break;
                    }
                }
            }
        );
    }
};

}

}

#endif