Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GetPackDimension to StateDescriptor #1148

Merged
merged 9 commits into from
Aug 14, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Current develop

### Added (new features/APIs/variables/...)
- [[PR 1148]](https://github.com/parthenon-hpc-lab/parthenon/pull/1148) Add `GetPackDimension` to `StateDescriptor` for calculating pack sizes before `Mesh` initialization
- [[PR 1143]](https://github.com/parthenon-hpc-lab/parthenon/pull/1143) Add tensor indices to VariableState, add radiation constant to constants, add TypeLists, allow for arbitrary containers for solvers
- [[PR 1140]](https://github.com/parthenon-hpc-lab/parthenon/pull/1140) Allow for relative convergence tolerance in BiCGSTAB solver.
- [[PR 1047]](https://github.com/parthenon-hpc-lab/parthenon/pull/1047) General three- and four-valent 2D forests w/ arbitrary orientations.
Expand Down
32 changes: 32 additions & 0 deletions src/interface/state_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <iomanip>
#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -566,4 +567,35 @@ StateDescriptor::GetVariableNames(const Metadata::FlagCollection &flags) {
return GetVariableNames({}, flags, {});
}

// Get the total length of this StateDescriptor's variables when packed
int StateDescriptor::GetPackDimension(const std::vector<std::string> &req_names,
const Metadata::FlagCollection &flags,
const std::vector<int> &sparse_ids) {
std::vector<std::string> names = GetVariableNames(req_names, flags, sparse_ids);
int dimension = 0;
for (auto name : names) {
const auto &meta = metadataMap_[VarID(name)];
// if meta.Shape().size() < 1, then 'accumulate' will return the initialization value,
// which is 1. Otherwise, this multiplies all elements present in 'Shape' to obtain
// total length
dimension += std::accumulate(meta.Shape().begin(), meta.Shape().end(), 1,
[](auto a, auto b) { return a * b; });
}
return dimension;
}
int StateDescriptor::GetPackDimension(const std::vector<std::string> &req_names,
const std::vector<int> &sparse_ids) {
return GetPackDimension(req_names, Metadata::FlagCollection(), sparse_ids);
}
int StateDescriptor::GetPackDimension(const Metadata::FlagCollection &flags,
const std::vector<int> &sparse_ids) {
return GetPackDimension({}, flags, sparse_ids);
}
int StateDescriptor::GetPackDimension(const std::vector<std::string> &req_names) {
return GetPackDimension(req_names, Metadata::FlagCollection(), {});
}
int StateDescriptor::GetPackDimension(const Metadata::FlagCollection &flags) {
return GetPackDimension({}, flags, {});
}

} // namespace parthenon
10 changes: 10 additions & 0 deletions src/interface/state_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,16 @@ class StateDescriptor {
std::vector<std::string> GetVariableNames(const std::vector<std::string> &req_names);
std::vector<std::string> GetVariableNames(const Metadata::FlagCollection &flags);

int GetPackDimension(const std::vector<std::string> &req_names,
const Metadata::FlagCollection &flags,
const std::vector<int> &sparse_ids);
int GetPackDimension(const std::vector<std::string> &req_names,
const std::vector<int> &sparse_ids);
int GetPackDimension(const Metadata::FlagCollection &flags,
const std::vector<int> &sparse_ids);
int GetPackDimension(const std::vector<std::string> &req_names);
int GetPackDimension(const Metadata::FlagCollection &flags);

std::size_t
RefinementFuncID(const refinement::RefinementFunctions_t &funcs) const noexcept {
return refinementFuncMaps_.funcs_to_ids.at(funcs);
Expand Down
16 changes: 16 additions & 0 deletions tst/unit/test_state_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,22 @@ TEST_CASE("Test Associate in StateDescriptor", "[StateDescriptor]") {
}
}

TEST_CASE("Test GetPackDimension in StateDescriptor", "[StateDescriptor]") {
GIVEN("Some flags and state descriptors") {
StateDescriptor state("state");
WHEN("We add some fields with various shapes and total size") {
state.AddField("foo", Metadata(std::vector<MetadataFlag>{}, std::vector<int>{4}));
state.AddField("bar",
Metadata(std::vector<MetadataFlag>{}, std::vector<int>{4, 4}));
state.AddField("baz",
Metadata(std::vector<MetadataFlag>{}, std::vector<int>{4, 4, 4}));
THEN("The total length is identified correctly") {
REQUIRE(state.GetPackDimension(Metadata::GetUserFlag("state")) == 84);
}
}
}
}

TEST_CASE("Test dependency resolution in StateDescriptor", "[StateDescriptor]") {
GIVEN("Some empty state descriptors and metadata") {
// metadata
Expand Down
Loading