Skip to content

Commit

Permalink
Select product of slices.
Browse files Browse the repository at this point in the history
  • Loading branch information
1uc committed Nov 9, 2023
1 parent 88fcc89 commit 889cd13
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 0 deletions.
16 changes: 16 additions & 0 deletions include/highfive/bits/H5Slice_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,20 @@ class HyperSlab {
std::vector<Select_> selects;
};

class ProductSet {
public:
template <class... Slices>
explicit ProductSet(const Slices&... slices);

private:
HyperSlab slab;
std::vector<size_t> shape;

template <typename Derivate>
friend class SliceTraits;
};


template <typename Derivate>
class SliceTraits {
public:
Expand Down Expand Up @@ -291,6 +305,8 @@ class SliceTraits {
///
Selection select(const ElementSet& elements) const;

Selection select(const ProductSet& product_set) const;

template <typename T>
T read(const DataTransferProps& xfer_props = DataTransferProps()) const;

Expand Down
126 changes: 126 additions & 0 deletions include/highfive/bits/H5Slice_traits_misc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,127 @@ inline ElementSet::ElementSet(const std::vector<std::vector<std::size_t>>& eleme
}
}

namespace detail {
class HyperCube {
public:
HyperCube(size_t rank)
: offset(rank)
, count(rank) {}

void cross(const std::array<size_t, 2>& range, size_t axis) {
offset[axis] = range[0];
count[axis] = range[1] - range[0];
}

RegularHyperSlab asSlab() {
return RegularHyperSlab(offset, count);
}

private:
std::vector<size_t> offset;
std::vector<size_t> count;
};

void build_hyper_slab(HyperSlab& slab, size_t /* axis */, HyperCube& cube) {
slab |= cube.asSlab();
}

template <class... Slices>
void build_hyper_slab(HyperSlab& slab,
size_t axis,
HyperCube& cube,
const std::array<size_t, 2>& slice,
const Slices&... higher_slices) {
cube.cross(slice, axis);
build_hyper_slab(slab, axis + 1, cube, higher_slices...);
}

template <class... Slices>
void build_hyper_slab(HyperSlab& slab,
size_t axis,
HyperCube& cube,
const std::vector<std::array<size_t, 2>>& slices,
const Slices&... higher_slices) {
for (const auto& slice: slices) {
build_hyper_slab(slab, axis, cube, slice, higher_slices...);
}
}

template <class... Slices>
void build_hyper_slab(HyperSlab& slab,
size_t axis,
HyperCube& cube,
const std::vector<size_t>& ids,
const Slices&... higher_slices) {
for (const auto& id: ids) {
build_hyper_slab(slab, axis, cube, {id, id + 1}, higher_slices...);
}
}

void compute_squashed_shape(size_t /* axis */, std::vector<size_t>& /* shape */) {
// assert(axis == shape.size());
}

Check warning on line 125 in include/highfive/bits/H5Slice_traits_misc.hpp

View check run for this annotation

Codecov / codecov/patch

include/highfive/bits/H5Slice_traits_misc.hpp#L125

Added line #L125 was not covered by tests

template <class... Slices>
void compute_squashed_shape(size_t axis,
std::vector<size_t>& shape,
const std::array<size_t, 2>& slice,
const Slices&... higher_slices);

template <class... Slices>
void compute_squashed_shape(size_t axis,
std::vector<size_t>& shape,
const std::vector<size_t>& points,
const Slices&... higher_slices);

template <class... Slices>
void compute_squashed_shape(size_t axis,
std::vector<size_t>& shape,
const std::vector<std::array<size_t, 2>>& slices,
const Slices&... higher_slices);

template <class... Slices>
void compute_squashed_shape(size_t axis,
std::vector<size_t>& shape,
const std::array<size_t, 2>& slice,
const Slices&... higher_slices) {
shape[axis] = slice[1] - slice[0];
compute_squashed_shape(axis + 1, shape, higher_slices...);
}

template <class... Slices>
void compute_squashed_shape(size_t axis,
std::vector<size_t>& shape,
const std::vector<size_t>& points,
const Slices&... higher_slices) {
shape[axis] = points.size();
compute_squashed_shape(axis + 1, shape, higher_slices...);
}

template <class... Slices>
void compute_squashed_shape(size_t axis,
std::vector<size_t>& shape,
const std::vector<std::array<size_t, 2>>& slices,
const Slices&... higher_slices) {
shape[axis] = 0;
for (const auto& slice: slices) {
shape[axis] += slice[1] - slice[0];
}
compute_squashed_shape(axis + 1, shape, higher_slices...);
}
} // namespace detail

template <class... Slices>
ProductSet::ProductSet(const Slices&... slices) {
auto rank = sizeof...(slices);
detail::HyperCube cube(rank);
detail::build_hyper_slab(slab, 0, cube, slices...);

shape = std::vector<size_t>(rank, size_t(0));
detail::compute_squashed_shape(0, shape, slices...);
}


template <typename Derivate>
inline Selection SliceTraits<Derivate>::select(const HyperSlab& hyperslab,
const DataSpace& memspace) const {
Expand Down Expand Up @@ -156,6 +277,11 @@ inline Selection SliceTraits<Derivate>::select(const ElementSet& elements) const
return detail::make_selection(DataSpace(num_elements), space, details::get_dataset(slice));
}

template <typename Derivate>
inline Selection SliceTraits<Derivate>::select(const ProductSet& product_set) const {
return this->select(product_set.slab, DataSpace(product_set.shape));
}


template <typename Derivate>
template <typename T>
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/tests_high_five_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1617,6 +1617,32 @@ TEMPLATE_LIST_TEST_CASE("irregularHyperSlabSelectionWrite", "[template]", std::t
irregularHyperSlabSelectionWriteTest<TestType>();
}

TEST_CASE("productSet_rR") {
const std::string file_name("h5_test_product_set_rRp.h5");

// clang-format off
std::vector<std::vector<double>> array{
{0.0, 0.1, 0.2, 0.3, 0.4, 0.5},
{1.0, 1.1, 1.2, 1.3, 1.4, 1.5},
{2.0, 2.1, 2.2, 2.3, 2.4, 2.5},
{3.0, 3.1, 3.2, 3.3, 3.4, 3.5},
{4.0, 4.1, 4.2, 4.3, 4.4, 4.5},
{5.0, 5.1, 5.2, 5.3, 5.4, 5.5},
{6.0, 6.1, 6.2, 6.3, 6.4, 6.5},
{7.0, 7.1, 7.2, 7.3, 7.4, 7.5}
};
// clang-format on

auto file = File(file_name, File::Truncate);
auto dset = file.createDataSet("dset", array);

std::vector<std::vector<double>> subarray;
dset.select(ProductSet(std::array<size_t, 2>{1, 3},
std::vector<std::array<size_t, 2>>{{0, 1}, {3, 5}}))
.read(subarray);
}


template <typename T>
void attribute_scalar_rw() {
std::ostringstream filename;
Expand Down

0 comments on commit 889cd13

Please sign in to comment.