1#ifndef TATAMI_PYTHON_SPARSE_MATRIX_HPP
2#define TATAMI_PYTHON_SPARSE_MATRIX_HPP
5#include "pybind11/pybind11.h"
22template<
typename Type_>
23void dump_to_buffer(
const pybind11::array& input, Type_*
const buffer) {
24 auto dtype = input.dtype();
25 if (dtype.is(pybind11::dtype::of<double>())) {
26 std::copy_n(
static_cast<const double*
>(input.request().ptr), input.size(), buffer);
28 }
else if (dtype.is(pybind11::dtype::of<float>())) {
29 std::copy_n(
static_cast<const float*
>(input.request().ptr), input.size(), buffer);
31 }
else if (dtype.is(pybind11::dtype::of<std::int64_t>())) {
32 std::copy_n(
static_cast<const std::int64_t*
>(input.request().ptr), input.size(), buffer);
34 }
else if (dtype.is(pybind11::dtype::of<std::int32_t>())) {
35 std::copy_n(
static_cast<const std::int32_t*
>(input.request().ptr), input.size(), buffer);
37 }
else if (dtype.is(pybind11::dtype::of<std::int16_t>())) {
38 std::copy_n(
static_cast<const std::int16_t*
>(input.request().ptr), input.size(), buffer);
40 }
else if (dtype.is(pybind11::dtype::of<std::int8_t>())) {
41 std::copy_n(
static_cast<const std::int8_t*
>(input.request().ptr), input.size(), buffer);
43 }
else if (dtype.is(pybind11::dtype::of<std::uint64_t>())) {
44 std::copy_n(
static_cast<const std::uint64_t*
>(input.request().ptr), input.size(), buffer);
46 }
else if (dtype.is(pybind11::dtype::of<std::uint32_t>())) {
47 std::copy_n(
static_cast<const std::uint32_t*
>(input.request().ptr), input.size(), buffer);
49 }
else if (dtype.is(pybind11::dtype::of<std::uint16_t>())) {
50 std::copy_n(
static_cast<const std::uint16_t*
>(input.request().ptr), input.size(), buffer);
52 }
else if (dtype.is(pybind11::dtype::of<std::uint8_t>())) {
53 std::copy_n(
static_cast<const std::uint8_t*
>(input.request().ptr), input.size(), buffer);
56 throw std::runtime_error(
"unrecognized array type '" + std::string(dtype.kind(), 1) + std::to_string(dtype.itemsize()) +
"' from 'extract_sparse_array()'");
/**
 * Walk over the `contents` of a Python Sparse2darray-like object, copying each entry's
 * index and/or value arrays into the supplied buffers and invoking `fun` once per entry.
 *
 * `matrix.contents` is expected to be None or a list; each list element is None or a
 * tuple of length 2, where element [0] is an index array and element [1] a value array.
 * Arrays are copied with dump_to_buffer(), which writes `array.size()` elements, so each
 * non-NULL buffer must be large enough for the longest entry.
 *
 * NOTE(review): the handling of a None `contents` and of None entries lives in lines not
 * visible in this view; presumably they represent all-zero data. Confirm against the full source.
 *
 * @tparam Value_ Element type of the value buffer.
 * @tparam Index_ Element type of the index buffer; also used for the shape and entry counts.
 * @tparam Function_ Callable accepting `(c, n)`, where `c` is the loop position and `n` is
 * the entry's number of stored elements (cast to `Index_`).
 *
 * @param matrix Python object with a `contents` attribute.
 * @param vbuffer Destination for values, or NULL if values are not required.
 * @param ibuffer Destination for indices, or NULL if indices are not required.
 * @param fun Callback invoked after each entry's buffers have been filled.
 *
 * @throws std::runtime_error If a non-None entry is not a tuple of length 2.
 */
87template<
typename Value_,
typename Index_,
class Function_>
88void parse_Sparse2darray(
const pybind11::object& matrix, Value_*
const vbuffer, Index_*
const ibuffer, Function_ fun) {
// Fetch the per-entry storage from the Python object.
89 pybind11::object raw_svt = matrix.attr(
"contents");
// A None 'contents' gets special handling here (body not visible in this view).
90 if (pybind11::isinstance<pybind11::none>(raw_svt)) {
93 auto svt = raw_svt.template cast<pybind11::list>();
// The loop length comes from shape.second (presumably the number of columns), not from
// the list's own length.
95 const auto shape = get_shape<Index_>(matrix);
96 const auto NC = shape.second;
98 for (I<
decltype(NC)> c = 0; c < NC; ++c) {
99 pybind11::object raw_inner(svt[c]);
// None entries are handled separately (body not visible in this view).
100 if (pybind11::isinstance<pybind11::none>(raw_inner)) {
104 auto inner = raw_inner.template cast<pybind11::tuple>();
105 if (inner.size() != 2) {
106 auto ctype = get_class_name(matrix);
107 throw std::runtime_error(
"each entry of '<" + ctype +
">.contents' should be a tuple of length 2 or None");
// The index array is always converted because its size is reported to 'fun' below,
// but it is only copied when the caller supplied an index buffer.
110 auto iinput = inner[0].template cast<pybind11::array>();
111 if (ibuffer != NULL) {
112 dump_to_buffer(iinput, ibuffer);
// The value array is only converted and copied when a value buffer was supplied.
// NOTE(review): no check that inner[1] has the same length as inner[0] is visible here.
114 if (vbuffer != NULL) {
115 auto vinput = inner[1].template cast<pybind11::array>();
116 dump_to_buffer(vinput, vbuffer);
// Hand the entry's position and its number of stored elements to the caller.
121 fun(c,
static_cast<Index_
>(iinput.size()));
/**
 * Fill caller-provided per-dimension destination arrays from a sparse Python matrix.
 *
 * `needs_value` / `needs_index` are derived from whether `value_ptrs` / `index_ptrs` are
 * non-empty. Only the buffers that are actually needed are passed to the parser call
 * (presumably parse_Sparse2darray(); the call target is not visible in this view).
 * The parser fills `vbuffer` / `ibuffer` for one entry at a time and then invokes the callback.
 *
 * The callback has two visible strategies:
 * - scatter: each stored element `i` is routed to destination `ibuffer[i]`, written at
 *   position `counts[ibuffer[i]]`, with the entry position `c` recorded as its index;
 * - direct copy: the `nnz` values/indices are copied contiguously into destination `c`.
 *
 * NOTE(review): `row` and `counts` are used below but declared in lines not visible here
 * (presumably extra parameters). The scatter path presumably corresponds to `row == true`.
 * `counts` is presumably zero-initialised by the caller. Confirm against the full source.
 *
 * @tparam CachedValue_ Element type of the value destinations.
 * @tparam CachedIndex_ Element type of the index destinations.
 * @tparam Index_ Integer type for positions and counts.
 *
 * @param matrix Python sparse matrix object.
 * @param value_ptrs Per-dimension value destinations; empty if values are not needed.
 * @param vbuffer Scratch buffer for one entry's values.
 * @param index_ptrs Per-dimension index destinations; empty if indices are not needed.
 * @param ibuffer Scratch buffer for one entry's indices.
 */
128template<
typename CachedValue_,
typename CachedIndex_,
typename Index_>
129void parse_sparse_matrix(
130 const pybind11::object& matrix,
132 std::vector<CachedValue_*>& value_ptrs,
133 CachedValue_*
const vbuffer,
134 std::vector<CachedIndex_*>& index_ptrs,
135 CachedIndex_*
const ibuffer,
// An empty destination vector means the caller does not want that component.
138 const bool needs_value = !value_ptrs.empty();
139 const bool needs_index = !index_ptrs.empty();
// Pass NULL for buffers that are not needed so the parser can skip that conversion. The index
// buffer is still needed when 'row' is set, because the scatter path below uses ibuffer[i]
// to choose each element's destination even when indices are not being returned.
143 (needs_value ? vbuffer : NULL),
144 (needs_index || row ? ibuffer : NULL),
145 [&](
const Index_ c,
const Index_ nnz) ->
void {
// Scatter path: route each element to the destination named by its index,
// appending at that destination's current fill position.
150 for (I<
decltype(nnz)> i = 0; i < nnz; ++i) {
151 auto ix = ibuffer[i];
152 value_ptrs[ix][counts[ix]] = vbuffer[i];
156 for (I<
decltype(nnz)> i = 0; i < nnz; ++i) {
157 auto ix = ibuffer[i];
158 index_ptrs[ix][counts[ix]] = c;
// Fill positions are advanced only after both writes above, so the value and the
// index for the same element land at the same position.
161 for (I<
decltype(nnz)> i = 0; i < nnz; ++i) {
162 ++(counts[ibuffer[i]]);
// Direct-copy path: this entry's data is contiguous and goes straight to destination 'c'.
167 std::copy_n(vbuffer, nnz, value_ptrs[c]);
170 std::copy_n(ibuffer, nnz, index_ptrs[c]);
tatami bindings for arbitrary Python matrices.
void parse_Sparse2darray(const pybind11::object &matrix, Value_ *const vbuffer, Index_ *const ibuffer, Function_ fun)
Definition sparse_matrix.hpp:88