tatami_python
Python bindings to tatami matrices
Loading...
Searching...
No Matches
sparse_matrix.hpp
Go to the documentation of this file.
1#ifndef TATAMI_PYTHON_SPARSE_MATRIX_HPP
2#define TATAMI_PYTHON_SPARSE_MATRIX_HPP
3
4#include "tatami/tatami.hpp"
5#include "pybind11/pybind11.h"
6
7#include "utils.hpp"
8
9#include <algorithm>
10#include <cstdint>
11
17namespace tatami_python {
18
22template<typename Type_>
23void dump_to_buffer(const pybind11::array& input, Type_* const buffer) {
24 auto dtype = input.dtype();
25 if (dtype.is(pybind11::dtype::of<double>())) {
26 std::copy_n(static_cast<const double*>(input.request().ptr), input.size(), buffer);
27
28 } else if (dtype.is(pybind11::dtype::of<float>())) {
29 std::copy_n(static_cast<const float*>(input.request().ptr), input.size(), buffer);
30
31 } else if (dtype.is(pybind11::dtype::of<std::int64_t>())) {
32 std::copy_n(static_cast<const std::int64_t*>(input.request().ptr), input.size(), buffer);
33
34 } else if (dtype.is(pybind11::dtype::of<std::int32_t>())) {
35 std::copy_n(static_cast<const std::int32_t*>(input.request().ptr), input.size(), buffer);
36
37 } else if (dtype.is(pybind11::dtype::of<std::int16_t>())) {
38 std::copy_n(static_cast<const std::int16_t*>(input.request().ptr), input.size(), buffer);
39
40 } else if (dtype.is(pybind11::dtype::of<std::int8_t>())) {
41 std::copy_n(static_cast<const std::int8_t*>(input.request().ptr), input.size(), buffer);
42
43 } else if (dtype.is(pybind11::dtype::of<std::uint64_t>())) {
44 std::copy_n(static_cast<const std::uint64_t*>(input.request().ptr), input.size(), buffer);
45
46 } else if (dtype.is(pybind11::dtype::of<std::uint32_t>())) {
47 std::copy_n(static_cast<const std::uint32_t*>(input.request().ptr), input.size(), buffer);
48
49 } else if (dtype.is(pybind11::dtype::of<std::uint16_t>())) {
50 std::copy_n(static_cast<const std::uint16_t*>(input.request().ptr), input.size(), buffer);
51
52 } else if (dtype.is(pybind11::dtype::of<std::uint8_t>())) {
53 std::copy_n(static_cast<const std::uint8_t*>(input.request().ptr), input.size(), buffer);
54
55 } else {
56 throw std::runtime_error("unrecognized array type '" + std::string(dtype.kind(), 1) + std::to_string(dtype.itemsize()) + "' from 'extract_sparse_array()'");
57 }
58}
87template<typename Value_, typename Index_, class Function_>
88void parse_Sparse2darray(const pybind11::object& matrix, Value_* const vbuffer, Index_* const ibuffer, Function_ fun) {
89 pybind11::object raw_svt = matrix.attr("contents");
90 if (pybind11::isinstance<pybind11::none>(raw_svt)) {
91 return;
92 }
93 auto svt = raw_svt.template cast<pybind11::list>();
94
95 const auto shape = get_shape<Index_>(matrix);
96 const auto NC = shape.second;
97
98 for (I<decltype(NC)> c = 0; c < NC; ++c) {
99 pybind11::object raw_inner(svt[c]);
100 if (pybind11::isinstance<pybind11::none>(raw_inner)) {
101 continue;
102 }
103
104 auto inner = raw_inner.template cast<pybind11::tuple>();
105 if (inner.size() != 2) {
106 auto ctype = get_class_name(matrix);
107 throw std::runtime_error("each entry of '<" + ctype + ">.contents' should be a tuple of length 2 or None");
108 }
109
110 auto iinput = inner[0].template cast<pybind11::array>();
111 if (ibuffer != NULL) {
112 dump_to_buffer(iinput, ibuffer);
113 }
114 if (vbuffer != NULL) {
115 auto vinput = inner[1].template cast<pybind11::array>();
116 dump_to_buffer(vinput, vbuffer);
117 }
118
119 // cast is known to be safe as the length of these vectors cannot excced
120 // the number of rows, the latter of which must fit in an Index_.
121 fun(c, static_cast<Index_>(iinput.size()));
122 }
123}
124
128template<typename CachedValue_, typename CachedIndex_, typename Index_>
129void parse_sparse_matrix(
130 const pybind11::object& matrix,
131 bool row,
132 std::vector<CachedValue_*>& value_ptrs,
133 CachedValue_* const vbuffer,
134 std::vector<CachedIndex_*>& index_ptrs,
135 CachedIndex_* const ibuffer,
136 Index_* const counts
137) {
138 const bool needs_value = !value_ptrs.empty();
139 const bool needs_index = !index_ptrs.empty();
140
142 matrix,
143 (needs_value ? vbuffer : NULL),
144 (needs_index || row ? ibuffer : NULL),
145 [&](const Index_ c, const Index_ nnz) -> void {
146 // Note that non-empty value_ptrs and index_ptrs may be longer than the
147 // number of rows/columns in the SVT matrix, due to the reuse of slabs.
148 if (row) {
149 if (needs_value) {
150 for (I<decltype(nnz)> i = 0; i < nnz; ++i) {
151 auto ix = ibuffer[i];
152 value_ptrs[ix][counts[ix]] = vbuffer[i];
153 }
154 }
155 if (needs_index) {
156 for (I<decltype(nnz)> i = 0; i < nnz; ++i) {
157 auto ix = ibuffer[i];
158 index_ptrs[ix][counts[ix]] = c;
159 }
160 }
161 for (I<decltype(nnz)> i = 0; i < nnz; ++i) {
162 ++(counts[ibuffer[i]]);
163 }
164
165 } else {
166 if (needs_value) {
167 std::copy_n(vbuffer, nnz, value_ptrs[c]);
168 }
169 if (needs_index) {
170 std::copy_n(ibuffer, nnz, index_ptrs[c]);
171 }
172 counts[c] = nnz;
173 }
174 }
175 );
176}
181}
182
183#endif
tatami bindings for arbitrary Python matrices.
void parse_Sparse2darray(const pybind11::object &matrix, Value_ *const vbuffer, Index_ *const ibuffer, Function_ fun)
Definition sparse_matrix.hpp:88