diff --git a/include/highfive/xtensor.hpp b/include/highfive/xtensor.hpp new file mode 100644 index 000000000..729a1e349 --- /dev/null +++ b/include/highfive/xtensor.hpp @@ -0,0 +1,212 @@ +#pragma once + +#include "bits/H5Inspector_decl.hpp" +#include "H5Exception.hpp" + +#include +#include +#include + +namespace HighFive { +namespace details { + +template +struct xtensor_get_rank; + +template +struct xtensor_get_rank> { + static constexpr size_t value = N; +}; + +template +struct xtensor_get_rank> { + static constexpr size_t value = N; +}; + +template +struct xtensor_inspector_base { + using type = XTensorType; + using value_type = typename type::value_type; + using base_type = typename inspector::base_type; + using hdf5_type = base_type; + + static_assert(std::is_same::value, + "HighFive's XTensor support only works for scalar elements."); + + static constexpr bool IsConstExprRowMajor = L == xt::layout_type::row_major; + static constexpr bool is_trivially_copyable = IsConstExprRowMajor && + std::is_trivially_copyable::value && + inspector::is_trivially_copyable; + + static constexpr bool is_trivially_nestable = false; + + static size_t getRank(const type& val) { + // Non-scalar elements are not supported. + return val.shape().size(); + } + + static const value_type& getAnyElement(const type& val) { + return val.unchecked(0); + } + + static value_type& getAnyElement(type& val) { + return val.unchecked(0); + } + + static std::vector getDimensions(const type& val) { + auto shape = val.shape(); + return {shape.begin(), shape.end()}; + } + + static void prepare(type& val, const std::vector& dims) { + val.resize(Derived::shapeFromDims(dims)); + } + + static hdf5_type* data(type& val) { + if (!is_trivially_copyable) { + throw DataSetException("Invalid used of `inspector::data`."); + } + + if (val.size() == 0) { + return nullptr; + } + + return inspector::data(getAnyElement(val)); + } + + static const hdf5_type* data(const type& val) { + if (!is_trivially_copyable) { + throw DataSetException("Invalid used of `inspector::data`."); + } + + if (val.size() == 0) { + return nullptr; + } + + return inspector::data(getAnyElement(val)); + } + + static void serialize(const type& val, const std::vector& dims, hdf5_type* m) { + // since we only support scalar types we know all dims belong to us. + size_t size = compute_total_size(dims); + xt::adapt(m, size, xt::no_ownership(), dims) = val; + } + + static void unserialize(const hdf5_type* vec_align, + const std::vector& dims, + type& val) { + // since we only support scalar types we know all dims belong to us. + size_t size = compute_total_size(dims); + val = xt::adapt(vec_align, size, xt::no_ownership(), dims); + } +}; + +template +struct xtensor_inspector + : public xtensor_inspector_base, XTensorType, L> { + private: + using super = xtensor_inspector_base, XTensorType, L>; + + public: + using type = typename super::type; + using value_type = typename super::value_type; + using base_type = typename super::base_type; + using hdf5_type = typename super::hdf5_type; + + static constexpr size_t ndim = xtensor_get_rank::value; + static constexpr size_t min_ndim = ndim + inspector::min_ndim; + static constexpr size_t max_ndim = ndim + inspector::max_ndim; + + static std::array shapeFromDims(const std::vector& dims) { + std::array shape; + std::copy(dims.cbegin(), dims.cend(), shape.begin()); + return shape; + } +}; + +template +struct xarray_inspector + : public xtensor_inspector_base, XArrayType, L> { + private: + using super = xtensor_inspector_base, XArrayType, L>; + + public: + using type = typename super::type; + using value_type = typename super::value_type; + using base_type = typename super::base_type; + using hdf5_type = typename super::hdf5_type; + + static constexpr size_t min_ndim = 0 + inspector::min_ndim; + static constexpr size_t max_ndim = 1024 + inspector::max_ndim; + + static const std::vector& shapeFromDims(const std::vector& dims) { + return dims; + } +}; + +template +struct inspector>: public xtensor_inspector, L> { + private: + using super = xtensor_inspector, L>; + + public: + using type = typename super::type; + using value_type = typename super::value_type; + using base_type = typename super::base_type; + using hdf5_type = typename super::hdf5_type; +}; + +template +struct inspector>: public xarray_inspector, L> { + private: + using super = xarray_inspector, L>; + + public: + using type = typename super::type; + using value_type = typename super::value_type; + using base_type = typename super::base_type; + using hdf5_type = typename super::hdf5_type; +}; + +template +struct inspector> + : public xarray_inspector, xt::layout_type::any> { + private: + using super = xarray_inspector, xt::layout_type::any>; + + public: + using type = typename super::type; + using value_type = typename super::value_type; + using base_type = typename super::base_type; + using hdf5_type = typename super::hdf5_type; +}; + + +template +struct inspector> + : public xarray_inspector, xt::layout_type::any> { + private: + using super = xarray_inspector, xt::layout_type::any>; + + public: + using type = typename super::type; + using value_type = typename super::value_type; + using base_type = typename super::base_type; + using hdf5_type = typename super::hdf5_type; +}; + +template +struct inspector> + : public xtensor_inspector, xt::layout_type::any> { + private: + using super = xtensor_inspector, xt::layout_type::any>; + + public: + using type = typename super::type; + using value_type = typename super::value_type; + using base_type = typename super::base_type; + using hdf5_type = typename super::hdf5_type; +}; + +} // namespace details +} // namespace HighFive diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 0f795812d..c8835ba34 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -6,7 +6,7 @@ if(MSVC) endif() ## Base tests -foreach(test_name tests_high_five_base tests_high_five_multi_dims tests_high_five_easy test_all_types test_high_five_selection tests_high_five_data_type test_empty_arrays test_legacy test_opencv test_string) +foreach(test_name tests_high_five_base tests_high_five_multi_dims tests_high_five_easy test_all_types test_high_five_selection tests_high_five_data_type test_empty_arrays test_legacy test_opencv test_string test_xtensor) add_executable(${test_name} "${test_name}.cpp") target_link_libraries(${test_name} HighFive HighFiveWarnings HighFiveFlags Catch2::Catch2WithMain) target_link_libraries(${test_name} HighFiveOptionalDependencies) @@ -47,7 +47,7 @@ endif() # test succeeds if it compiles. file(GLOB public_headers LIST_DIRECTORIES false RELATIVE ${PROJECT_SOURCE_DIR}/include CONFIGURE_DEPENDS ${PROJECT_SOURCE_DIR}/include/highfive/*.hpp) foreach(PUBLIC_HEADER ${public_headers}) - if(PUBLIC_HEADER STREQUAL "highfive/span.hpp" AND NOT HIGHFIVE_TEST_SPAN) + if(PUBLIC_HEADER STREQUAL "highfive/span.hpp" AND NOT HIGHFIVE_TEST_SPAN) continue() endif() @@ -67,6 +67,10 @@ foreach(PUBLIC_HEADER ${public_headers}) continue() endif() + if(PUBLIC_HEADER STREQUAL "highfive/xtensor.hpp" AND NOT HIGHFIVE_TEST_XTENSOR) + continue() + endif() + get_filename_component(CLASS_NAME ${PUBLIC_HEADER} NAME_WE) configure_file(tests_import_public_headers.cpp "tests_${CLASS_NAME}.cpp" @ONLY) add_executable("tests_include_${CLASS_NAME}" "${CMAKE_CURRENT_BINARY_DIR}/tests_${CLASS_NAME}.cpp") diff --git a/tests/unit/data_generator.hpp b/tests/unit/data_generator.hpp index a50284d5c..d513c3420 100644 --- a/tests/unit/data_generator.hpp +++ b/tests/unit/data_generator.hpp @@ -21,11 +21,16 @@ #include #endif +#ifdef HIGHFIVE_TEST_XTENSOR +#include +#endif + namespace HighFive { namespace testing { -std::vector lstrip(const std::vector& indices, size_t n) { +template +std::vector lstrip(const Dims& indices, size_t n) { std::vector subindices(indices.size() - n); for (size_t i = 0; i < subindices.size(); ++i) { subindices[i] = indices[i + n]; @@ -34,7 +39,8 @@ std::vector lstrip(const std::vector& indices, size_t n) { return subindices; } -size_t ravel(std::vector& indices, const std::vector dims) { +template +size_t ravel(std::vector& indices, const Dims& dims) { size_t rank = dims.size(); size_t linear_index = 0; size_t ld = 1; @@ -47,7 +53,8 @@ size_t ravel(std::vector& indices, const std::vector dims) { return linear_index; } -std::vector unravel(size_t flat_index, const std::vector dims) { +template +std::vector unravel(size_t flat_index, const Dims& dims) { size_t rank = dims.size(); size_t ld = 1; std::vector indices(rank); @@ -60,7 +67,8 @@ std::vector unravel(size_t flat_index, const std::vector dims) { return indices; } -static size_t flat_size(const std::vector& dims) { +template +static size_t flat_size(const Dims& dims) { size_t n = 1; for (auto d: dims) { n *= d; @@ -388,6 +396,7 @@ struct ContainerTraits> { #endif +// -- Eigen ------------------------------------------------------------------- #if HIGHFIVE_TEST_EIGEN template @@ -525,6 +534,88 @@ struct ContainerTraits> }; +#endif + +// -- XTensor ----------------------------------------------------------------- + +#if HIGHFIVE_TEST_XTENSOR +template +struct XTensorContainerTraits { + using container_type = XTensorType; + using value_type = typename container_type::value_type; + using base_type = typename ContainerTraits::base_type; + + static constexpr size_t rank = Rank; + static constexpr bool is_view = ContainerTraits::is_view; + + static void set(container_type& array, + const std::vector& indices, + const base_type& value) { + std::vector local_indices(indices.begin(), indices.begin() + rank); + return ContainerTraits::set(array[local_indices], lstrip(indices, rank), value); + } + + static base_type get(const container_type& array, const std::vector& indices) { + std::vector local_indices(indices.begin(), indices.begin() + rank); + return ContainerTraits::get(array[local_indices], lstrip(indices, rank)); + } + + static void assign(container_type& dst, const container_type& src) { + dst = src; + } + + static container_type allocate(const std::vector& dims) { + const auto& local_dims = details::inspector::shapeFromDims(dims); + auto array = container_type(local_dims); + + size_t n_elements = flat_size(local_dims); + for (size_t i = 0; i < n_elements; ++i) { + auto element = ContainerTraits::allocate(lstrip(dims, rank)); + set(array, unravel(i, local_dims), element); + } + + return array; + } + + static void deallocate(container_type& array, const std::vector& dims) { + auto local_dims = std::vector(dims.begin(), dims.begin() + rank); + size_t n_elements = flat_size(local_dims); + for (size_t i_flat = 0; i_flat < n_elements; ++i_flat) { + auto indices = unravel(i_flat, local_dims); + std::vector local_indices(indices.begin(), indices.begin() + rank); + ContainerTraits::deallocate(array[local_indices], lstrip(dims, rank)); + } + } + + static void sanitize_dims(std::vector& dims, size_t axis) { + ContainerTraits::sanitize_dims(dims, axis + rank); + } +}; + +template +struct ContainerTraits> + : public XTensorContainerTraits, rank> { + private: + using super = XTensorContainerTraits, rank>; + + public: + using container_type = typename super::container_type; + using value_type = typename super::value_type; + using base_type = typename super::base_type; +}; + +template +struct ContainerTraits> + : public XTensorContainerTraits, 2> { + private: + using super = XTensorContainerTraits, 2>; + + public: + using container_type = typename super::container_type; + using value_type = typename super::value_type; + using base_type = typename super::base_type; +}; + #endif template diff --git a/tests/unit/supported_types.hpp b/tests/unit/supported_types.hpp index 4d703949d..5f1bf6453 100644 --- a/tests/unit/supported_types.hpp +++ b/tests/unit/supported_types.hpp @@ -82,6 +82,20 @@ struct EigenMapMatrix { }; #endif +#ifdef HIGHFIVE_TEST_XTENSOR +template +struct XTensor { + template + using type = xt::xtensor, rank, layout>; +}; + +template +struct XArray { + template + using type = xt::xarray, layout>; +}; +#endif + template struct ContainerProduct; @@ -165,6 +179,16 @@ using supported_array_types = typename ConcatenateTuples< typename ContainerProduct>, some_scalar_types>::type, typename ContainerProduct>, some_scalar_types>::type, typename ContainerProduct>, some_scalar_types>::type, +#endif +#ifdef HIGHFIVE_TEST_XTENSOR + typename ContainerProduct, scalar_types_eigen>::type, + typename ContainerProduct>, scalar_types_eigen>::type, + typename ContainerProduct>, scalar_types_eigen>::type, + typename ContainerProduct, scalar_types_eigen>::type, + typename ContainerProduct, scalar_types_eigen>::type, + typename ContainerProduct>, scalar_types_eigen>::type, + typename ContainerProduct>, scalar_types_eigen>::type, + typename ContainerProduct, scalar_types_eigen>::type, #endif typename ContainerProduct, all_scalar_types>::type, typename ContainerProduct>, some_scalar_types>::type, diff --git a/tests/unit/test_xtensor.cpp b/tests/unit/test_xtensor.cpp new file mode 100644 index 000000000..ac0b4d743 --- /dev/null +++ b/tests/unit/test_xtensor.cpp @@ -0,0 +1,142 @@ +/* + * Copyright (c), 2024, Blue Brain Project - EPFL + * + * Distributed under the Boost Software License, Version 1.0. + * (See accompanying file LICENSE_1_0.txt or copy at + * http://www.boost.org/LICENSE_1_0.txt) + * + */ +#if HIGHFIVE_TEST_XTENSOR +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include "data_generator.hpp" + +using namespace HighFive; + +template +std::array asStaticShape(const std::vector& dims) { + assert(dims.size() == N); + + std::array shape; + std::copy(dims.cbegin(), dims.cend(), shape.begin()); + + return shape; +} + +TEST_CASE("xt::xarray reshape", "[xtensor]") { + const std::string file_name("rw_dataset_xarray.h5"); + + File file(file_name, File::Truncate); + + std::vector shape{3, 2, 4}; + std::vector compatible_shape{1, 3, 2, 4}; + std::vector incompatible_shape{5, 2, 4}; + + xt::xarray a = testing::DataGenerator>::create(shape); + xt::xarray b(compatible_shape); + xt::xarray c(incompatible_shape); + + auto dset = file.createDataSet("baz", a); + + SECTION("xarray_adaptor") { + // Changes the shape. + auto b_adapt = xt::adapt(b.data(), b.size(), xt::no_ownership(), b.shape()); + dset.read(b_adapt); + REQUIRE(b_adapt.shape() == shape); + + // But can't change the number of elements. + auto c_adapt = xt::adapt(c.data(), c.size(), xt::no_ownership(), c.shape()); + REQUIRE_THROWS(dset.read(c_adapt)); + } + + SECTION("xtensor_adaptor") { + auto b_shape = asStaticShape<4>(compatible_shape); + auto c_shape = asStaticShape<3>(incompatible_shape); + + // Doesn't change the shape: + auto b_adapt = xt::adapt(b.data(), b.size(), xt::no_ownership(), b_shape); + REQUIRE_THROWS(dset.read(b_adapt)); + + // and can't change the number of elements: + auto c_adapt = xt::adapt(c.data(), c.size(), xt::no_ownership(), c_shape); + REQUIRE_THROWS(dset.read(c_adapt)); + } +} + +TEST_CASE("xt::xview example", "[xtensor]") { + File file("rw_dataset_xview.h5", File::Truncate); + + std::vector shape{13, 5, 7}; + xt::xarray a = testing::DataGenerator>::create(shape); + auto c = xt::view(a, xt::range(3, 31, 4), xt::all(), xt::drop(0, 3, 4, 5)); + + auto dset = file.createDataSet("c", c); + auto d = dset.read>(); + auto e = dset.read>(); + + REQUIRE(d == c); + REQUIRE(e == c); +} + +template +void check_xtensor_scalar(File& file) { + XTensor a; + a = 42.0; + REQUIRE(a.shape() == std::vector{}); + + SECTION("read") { + auto dset = file.createDataSet("a", a); + REQUIRE(dset.template read() == a(0)); + } + + SECTION("write") { + double b = -42.0; + auto dset = file.createDataSet("b", b); + REQUIRE(dset.template read>()(0) == b); + } +} + +TEST_CASE("xt::xarray scalar", "[xtensor]") { + File file("rw_dataset_xarray_scalar.h5", File::Truncate); + check_xtensor_scalar>(file); +} + +TEST_CASE("xt::xtensor scalar", "[xtensor]") { + File file("rw_dataset_xtensor_scalar.h5", File::Truncate); + check_xtensor_scalar>(file); +} + +template +void check_xtensor_empty(File& file, const XTensor& a, const std::vector& expected_dims) { + auto dset = file.createDataSet("a", a); + auto b = dset.template read(); + REQUIRE(b.size() == 0); + REQUIRE(b == a); + + auto c = std::vector{}; + auto c_shape = details::inspector::getDimensions(c); + REQUIRE(c_shape == expected_dims); +} + +TEST_CASE("xt::xtensor empty", "[xtensor]") { + File file("rw_dataset_xtensor_empty.h5", File::Truncate); + xt::xtensor a({0, 1, 1}); + check_xtensor_empty(file, a, {0, 1, 1, 1}); +} + +TEST_CASE("xt::xarray empty", "[xtensor]") { + File file("rw_dataset_xarray_empty.h5", File::Truncate); + xt::xarray a(std::vector{1, 0, 1}); + check_xtensor_empty(file, a, {0}); +} + +#endif diff --git a/tests/unit/tests_high_five.hpp b/tests/unit/tests_high_five.hpp index 25839c69e..d9a4ed34d 100644 --- a/tests/unit/tests_high_five.hpp +++ b/tests/unit/tests_high_five.hpp @@ -21,6 +21,7 @@ // The list of identifiers is taken from `Boost::Predef`. #if defined(_WIN32) || defined(_WIN64) || defined(__WIN32__) || defined(__TOS_WIN__) || \ defined(__WINDOWS__) +#define NOMINMAX #include #endif