Skip to content

Commit

Permalink
Assert that XTensor objects are row-major.
Browse files Browse the repository at this point in the history
Easy serializes XTensor objects by obtaining a pointer to the first
element, and then using `write_raw`. Same for reading using `read_raw`.

Therefore, it only supports (a subset of) row-major arrays. This commit
as a runtime check.
  • Loading branch information
1uc committed Apr 19, 2024
1 parent c1b3cde commit 8d1a95e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
12 changes: 12 additions & 0 deletions include/highfive/h5easy_bits/H5Easy_xtensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ namespace detail {

template <typename T>
struct io_impl<T, typename std::enable_if<xt::is_xexpression<T>::value>::type> {

inline static void assert_row_major(const File& file, const std::string& path, const T& data) {
if(data.layout() != xt::layout_type::row_major) {
throw detail::error(file, path, "Only row-major XTensor object are supported.");
}
}

inline static std::vector<size_t> shape(const T& data) {
return std::vector<size_t>(data.shape().cbegin(), data.shape().cend());
}
Expand All @@ -28,6 +35,8 @@ struct io_impl<T, typename std::enable_if<xt::is_xexpression<T>::value>::type> {
const std::string& path,
const T& data,
const DumpOptions& options) {

assert_row_major(file, path, data);
using value_type = typename std::decay_t<T>::value_type;
DataSet dataset = initDataset<value_type>(file, path, shape(data), options);
dataset.write_raw(data.data());
Expand All @@ -44,6 +53,7 @@ struct io_impl<T, typename std::enable_if<xt::is_xexpression<T>::value>::type> {
DataSet dataset = file.getDataSet(path);
std::vector<size_t> dims = dataset.getDimensions();
T data = T::from_shape(dims);
assert_row_major(file, path, data);
dataset.read_raw(data.data());
return data;
}
Expand All @@ -53,6 +63,7 @@ struct io_impl<T, typename std::enable_if<xt::is_xexpression<T>::value>::type> {
const std::string& key,
const T& data,
const DumpOptions& options) {
assert_row_major(file, path, data);
using value_type = typename std::decay_t<T>::value_type;
Attribute attribute = initAttribute<value_type>(file, path, key, shape(data), options);
attribute.write_raw(data.data());
Expand All @@ -73,6 +84,7 @@ struct io_impl<T, typename std::enable_if<xt::is_xexpression<T>::value>::type> {
DataSpace dataspace = attribute.getSpace();
std::vector<size_t> dims = dataspace.getDimensions();
T data = T::from_shape(dims);
assert_row_major(file, path, data);
attribute.read_raw(data.data());
return data;
}
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/tests_high_five_easy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,44 @@ TEST_CASE("H5Easy_xtensor") {
CHECK(xt::all(xt::equal(B, B_r)));
}

TEST_CASE("H5Easy_xtensor_column_major") {
H5Easy::File file("h5easy_xtensor_colum_major.h5", H5Easy::File::Overwrite);

using column_major_t = xt::xtensor<double, 2, xt::layout_type::column_major>;

xt::xtensor<double, 2> A = 100. * xt::random::randn<double>({20, 5});

H5Easy::dump(file, "/path/to/A", A);

SECTION("Write column major") {
column_major_t B = A;
REQUIRE_THROWS(H5Easy::dump(file, "path/to/B", B));
}

SECTION("Read column major") {
REQUIRE_THROWS(H5Easy::load<column_major_t>(file, "/path/to/A"));
}
}

TEST_CASE("H5Easy_xarray_column_major") {
H5Easy::File file("h5easy_xarray_colum_major.h5", H5Easy::File::Overwrite);

using column_major_t = xt::xarray<double, xt::layout_type::column_major>;

xt::xarray<double> A = 100. * xt::random::randn<double>({20, 5});

H5Easy::dump(file, "/path/to/A", A);

SECTION("Write column major") {
column_major_t B = A;
REQUIRE_THROWS(H5Easy::dump(file, "path/to/B", B));
}

SECTION("Read column major") {
REQUIRE_THROWS(H5Easy::load<column_major_t>(file, "/path/to/A"));
}
}

TEST_CASE("H5Easy_xarray") {
H5Easy::File file("h5easy_xarray.h5", H5Easy::File::Overwrite);

Expand Down

0 comments on commit 8d1a95e

Please sign in to comment.