Skip to content

Commit

Permalink
implement reading into existing buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
llohse committed Oct 3, 2024
1 parent 471fe48 commit f6d1db9
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions include/npy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,43 @@ inline npy_data<Scalar> read_npy(std::istream &in) {
return data;
}

template <typename Scalar, typename It>
inline npy_data_ptr<Scalar> read_npy(std::istream &in, It dst_buf, size_t count) {
static_assert(is_same<Scalar, std::iterator_traits<It>::value_type>,
"Scalar and It::value_type must be identical");

std::string header_s = read_header(in);

// parse header
header_t header = parse_header(header_s);

// check if the typestring matches the given one
const dtype_t dtype = dtype_map.at(std::type_index(typeid(Scalar)));

if (header.dtype.tie() != dtype.tie()) {
throw std::runtime_error("formatting error: typestrings not matching");
}

// compute the data size based on the shape
auto size = static_cast<size_t>(comp_size(header.shape));

if (size > count) {
throw std::runtime_error("dst_buf too small to hold file contents");
}

npy_data_ptr<Scalar> data;

data.shape = header.shape;
data.fortran_order = header.fortran_order;

data.data_ptr = &(*dst_buf);

// read the data
in.read(reinterpret_cast<char *>(data.data_ptr), sizeof(Scalar) * size);

return data;
}

template <typename Scalar>
inline npy_data<Scalar> read_npy(const std::string &filename) {
std::ifstream stream(filename, std::ifstream::binary);
Expand All @@ -513,6 +550,16 @@ inline npy_data<Scalar> read_npy(const std::string &filename) {
return read_npy<Scalar>(stream);
}

template <typename Scalar, typename It>
inline npy_data_ptr<Scalar> read_npy(const std::string &filename, It dst_buf, size_t count) {
std::ifstream stream(filename, std::ifstream::binary);
if (!stream) {
throw std::runtime_error("io error: failed to open a file.");
}

return read_npy<Scalar>(stream, dst_buf, count);
}

template <typename Scalar>
inline void write_npy(std::ostream &out, const npy_data<Scalar> &data) {
// static_assert(has_typestring<Scalar>::value, "scalar type not
Expand Down

0 comments on commit f6d1db9

Please sign in to comment.