tatami_mult
Multiply tatami matrices
tatami_mult.hpp
#ifndef TATAMI_MULT_HPP
#define TATAMI_MULT_HPP

#include "dense_row.hpp"
#include "sparse_row.hpp"
#include "dense_column.hpp"
#include "sparse_column.hpp"

#include <vector>

#include "tatami/tatami.hpp"
#include "sanisizer/sanisizer.hpp"

/**
 * @file tatami_mult.hpp
 * @brief Multiplication of tatami matrices.
 */

namespace tatami_mult {

/**
 * @brief Options for `multiply()`.
 */
struct Options {
    /**
     * Number of threads to use for the multiplication.
     */
    int num_threads = 1;

    /**
     * For the matrix-matrix `multiply()`, whether to compute the product in the
     * orientation with more output rows, transposing both operands if necessary.
     */
    bool prefer_larger = true;

    /**
     * For the matrix-matrix `multiply()`, whether the output buffer should be
     * filled in column-major order (row-major otherwise).
     */
    bool column_major_output = true;
};

template<typename Value_, typename Index_, typename Right_, typename Output_>
void multiply(const tatami::Matrix<Value_, Index_>& left, const Right_* right, Output_* output, const Options& opt) {
    if (left.sparse()) {
        if (left.prefer_rows()) {
            internal::sparse_row_vector(left, right, output, opt.num_threads);
        } else {
            internal::sparse_column_vector(left, right, output, opt.num_threads);
        }
    } else {
        if (left.prefer_rows()) {
            internal::dense_row_vector(left, right, output, opt.num_threads);
        } else {
            internal::dense_column_vector(left, right, output, opt.num_threads);
        }
    }
}
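/*
 * Usage sketch for the overload above, assuming an existing tatami::Matrix<double, int>
 * named `mat` (the variable names and element types are illustrative, not library API):
 *
 *     std::vector<double> vec(mat.ncol(), 1.0); // right-hand vector, length = ncol
 *     std::vector<double> res(mat.nrow());      // product, length = nrow
 *     tatami_mult::Options opt;
 *     tatami_mult::multiply(mat, vec.data(), res.data(), opt);
 */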

template<typename Left_, typename Value_, typename Index_, typename Output_>
void multiply(const Left_* left, const tatami::Matrix<Value_, Index_>& right, Output_* output, const Options& opt) {
    auto tright = tatami::make_DelayedTranspose(tatami::wrap_shared_ptr(&right));
    if (tright->sparse()) {
        if (tright->prefer_rows()) {
            internal::sparse_row_vector(*tright, left, output, opt.num_threads);
        } else {
            internal::sparse_column_vector(*tright, left, output, opt.num_threads);
        }
    } else {
        if (tright->prefer_rows()) {
            internal::dense_row_vector(*tright, left, output, opt.num_threads);
        } else {
            internal::dense_column_vector(*tright, left, output, opt.num_threads);
        }
    }
}
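/*
 * Usage sketch for the overload above, computing a (row) vector-matrix product with an
 * existing tatami::Matrix<double, int> `mat` (names and types are illustrative only):
 *
 *     std::vector<double> vec(mat.nrow(), 1.0); // left-hand vector, length = nrow
 *     std::vector<double> res(mat.ncol());      // product, length = ncol
 *     tatami_mult::Options opt;
 *     tatami_mult::multiply(vec.data(), mat, res.data(), opt);
 */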

template<typename Value_, typename Index_, typename Right_, typename Output_>
void multiply(const tatami::Matrix<Value_, Index_>& left, const std::vector<Right_*>& right, const std::vector<Output_*>& output, const Options& opt) {
    if (left.sparse()) {
        if (left.prefer_rows()) {
            internal::sparse_row_vectors(left, right, output, opt.num_threads);
        } else {
            internal::sparse_column_vectors(left, right, output, opt.num_threads);
        }
    } else {
        if (left.prefer_rows()) {
            internal::dense_row_vectors(left, right, output, opt.num_threads);
        } else {
            internal::dense_column_vectors(left, right, output, opt.num_threads);
        }
    }
}
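/*
 * Usage sketch for the overload above, multiplying `mat` against several right-hand
 * vectors at once (again, the names and element types are illustrative assumptions):
 *
 *     std::vector<double> v1(mat.ncol()), v2(mat.ncol());
 *     std::vector<double> r1(mat.nrow()), r2(mat.nrow());
 *     std::vector<double*> rhs{ v1.data(), v2.data() };
 *     std::vector<double*> res{ r1.data(), r2.data() };
 *     tatami_mult::Options opt;
 *     tatami_mult::multiply(mat, rhs, res, opt); // res[k] holds mat * rhs[k]
 */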

template<typename Left_, typename Value_, typename Index_, typename Output_>
void multiply(const std::vector<Left_*>& left, const tatami::Matrix<Value_, Index_>& right, const std::vector<Output_*>& output, const Options& opt) {
    auto tright = tatami::make_DelayedTranspose(tatami::wrap_shared_ptr(&right));
    if (tright->sparse()) {
        if (tright->prefer_rows()) {
            internal::sparse_row_vectors(*tright, left, output, opt.num_threads);
        } else {
            internal::sparse_column_vectors(*tright, left, output, opt.num_threads);
        }
    } else {
        if (tright->prefer_rows()) {
            internal::dense_row_vectors(*tright, left, output, opt.num_threads);
        } else {
            internal::dense_column_vectors(*tright, left, output, opt.num_threads);
        }
    }
}
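/*
 * Usage sketch for the overload above, the multi-vector analogue of the vector-matrix
 * product: each left-hand vector has length nrow and each output has length ncol
 * (names and types are illustrative only):
 *
 *     std::vector<double> v1(mat.nrow()), v2(mat.nrow());
 *     std::vector<double> r1(mat.ncol()), r2(mat.ncol());
 *     std::vector<double*> lhs{ v1.data(), v2.data() };
 *     std::vector<double*> res{ r1.data(), r2.data() };
 *     tatami_mult::Options opt;
 *     tatami_mult::multiply(lhs, mat, res, opt);
 */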

namespace internal {

template<typename LeftValue_, typename LeftIndex_, typename RightValue_, typename RightIndex_, typename Output_>
void multiply(const tatami::Matrix<LeftValue_, LeftIndex_>& left, const tatami::Matrix<RightValue_, RightIndex_>& right, Output_* output, bool column_major_out, int num_threads) {
    RightIndex_ row_shift;
    LeftIndex_ col_shift;
    if (column_major_out) {
        row_shift = 1;
        col_shift = left.nrow();
    } else {
        row_shift = right.ncol();
        col_shift = 1;
    }
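    // Entry (i, j) of the product is written to output[i * row_shift + j * col_shift]:
    // with a column-major output this is output[i + j * left.nrow()], while with a
    // row-major output it is output[i * right.ncol() + j].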

    if (left.sparse()) {
        if (left.prefer_rows()) {
            if (right.sparse()) {
                internal::sparse_row_tatami_sparse(left, right, output, row_shift, col_shift, num_threads);
            } else {
                internal::sparse_row_tatami_dense(left, right, output, row_shift, col_shift, num_threads);
            }
        } else {
            if (right.sparse()) {
                internal::sparse_column_tatami_sparse(left, right, output, row_shift, col_shift, num_threads);
            } else {
                internal::sparse_column_tatami_dense(left, right, output, row_shift, col_shift, num_threads);
            }
        }
    } else {
        if (left.prefer_rows()) {
            if (right.sparse()) {
                internal::dense_row_tatami_sparse(left, right, output, row_shift, col_shift, num_threads);
            } else {
                internal::dense_row_tatami_dense(left, right, output, row_shift, col_shift, num_threads);
            }
        } else {
            if (right.sparse()) {
                internal::dense_column_tatami_sparse(left, right, output, row_shift, col_shift, num_threads);
            } else {
                internal::dense_column_tatami_dense(left, right, output, row_shift, col_shift, num_threads);
            }
        }
    }
}

} // namespace internal

template<typename LeftValue_, typename LeftIndex_, typename RightValue_, typename RightIndex_, typename Output_>
void multiply(const tatami::Matrix<LeftValue_, LeftIndex_>& left, const tatami::Matrix<RightValue_, RightIndex_>& right, Output_* output, const Options& opt) {
    if (opt.prefer_larger) {
        if (sanisizer::is_less_than(left.nrow(), right.ncol())) {
            // Compute the transposed product (right^T * left^T) = (left * right)^T instead,
            // so that the larger output dimension becomes the rows; flipping the output's
            // storage order leaves the result in the same memory layout as left * right.
            auto tright = tatami::make_DelayedTranspose(tatami::wrap_shared_ptr(&right));
            auto tleft = tatami::make_DelayedTranspose(tatami::wrap_shared_ptr(&left));
            internal::multiply(*tright, *tleft, output, !opt.column_major_output, opt.num_threads);
            return;
        }
    }

    internal::multiply(left, right, output, opt.column_major_output, opt.num_threads);
}
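/*
 * Usage sketch for the overload above, multiplying two tatami matrices `A` (nr x nk)
 * and `B` (nk x nc) into a dense buffer; the names, types and the choice of a
 * column-major output are illustrative assumptions:
 *
 *     std::vector<double> out(static_cast<std::size_t>(A.nrow()) * B.ncol());
 *     tatami_mult::Options opt;
 *     opt.num_threads = 4;
 *     opt.column_major_output = true; // entry (i, j) ends up at out[i + j * A.nrow()]
 *     tatami_mult::multiply(A, B, out.data(), opt);
 */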

} // namespace tatami_mult

#endif