diff --git a/CHANGELOG.md b/CHANGELOG.md index c404edb308a7..2cfc0df66b46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## Current develop ### Added (new features/APIs/variables/...) +- [[PR 1148]](https://github.com/parthenon-hpc-lab/parthenon/pull/1148) Add `GetPackDimension` to `StateDescriptor` for calculating pack sizes before `Mesh` initialization - [[PR 1143]](https://github.com/parthenon-hpc-lab/parthenon/pull/1143) Add tensor indices to VariableState, add radiation constant to constants, add TypeLists, allow for arbitrary containers for solvers - [[PR 1140]](https://github.com/parthenon-hpc-lab/parthenon/pull/1140) Allow for relative convergence tolerance in BiCGSTAB solver. - [[PR 1047]](https://github.com/parthenon-hpc-lab/parthenon/pull/1047) General three- and four-valent 2D forests w/ arbitrary orientations. diff --git a/src/interface/state_descriptor.cpp b/src/interface/state_descriptor.cpp index 308ba14a7b67..aa3fa5e6af27 100644 --- a/src/interface/state_descriptor.cpp +++ b/src/interface/state_descriptor.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include #include @@ -566,4 +567,35 @@ StateDescriptor::GetVariableNames(const Metadata::FlagCollection &flags) { return GetVariableNames({}, flags, {}); } +// Get the total length of this StateDescriptor's variables when packed +int StateDescriptor::GetPackDimension(const std::vector &req_names, + const Metadata::FlagCollection &flags, + const std::vector &sparse_ids) { + std::vector names = GetVariableNames(req_names, flags, sparse_ids); + int dimension = 0; + for (auto name : names) { + const auto &meta = metadataMap_[VarID(name)]; + // if meta.Shape().size() < 1, then 'accumulate' will return the initialization value, + // which is 1. Otherwise, this multiplies all elements present in 'Shape' to obtain + // total length + dimension += std::accumulate(meta.Shape().begin(), meta.Shape().end(), 1, + [](auto a, auto b) { return a * b; }); + } + return dimension; +} +int StateDescriptor::GetPackDimension(const std::vector &req_names, + const std::vector &sparse_ids) { + return GetPackDimension(req_names, Metadata::FlagCollection(), sparse_ids); +} +int StateDescriptor::GetPackDimension(const Metadata::FlagCollection &flags, + const std::vector &sparse_ids) { + return GetPackDimension({}, flags, sparse_ids); +} +int StateDescriptor::GetPackDimension(const std::vector &req_names) { + return GetPackDimension(req_names, Metadata::FlagCollection(), {}); +} +int StateDescriptor::GetPackDimension(const Metadata::FlagCollection &flags) { + return GetPackDimension({}, flags, {}); +} + } // namespace parthenon diff --git a/src/interface/state_descriptor.hpp b/src/interface/state_descriptor.hpp index 47f1a9138d09..c488d53a1bcd 100644 --- a/src/interface/state_descriptor.hpp +++ b/src/interface/state_descriptor.hpp @@ -276,6 +276,16 @@ class StateDescriptor { std::vector GetVariableNames(const std::vector &req_names); std::vector GetVariableNames(const Metadata::FlagCollection &flags); + int GetPackDimension(const std::vector &req_names, + const Metadata::FlagCollection &flags, + const std::vector &sparse_ids); + int GetPackDimension(const std::vector &req_names, + const std::vector &sparse_ids); + int GetPackDimension(const Metadata::FlagCollection &flags, + const std::vector &sparse_ids); + int GetPackDimension(const std::vector &req_names); + int GetPackDimension(const Metadata::FlagCollection &flags); + std::size_t RefinementFuncID(const refinement::RefinementFunctions_t &funcs) const noexcept { return refinementFuncMaps_.funcs_to_ids.at(funcs); diff --git a/tst/unit/test_state_descriptor.cpp b/tst/unit/test_state_descriptor.cpp index 706f4c73c74d..95890f9caf11 100644 --- a/tst/unit/test_state_descriptor.cpp +++ b/tst/unit/test_state_descriptor.cpp @@ -121,6 +121,22 @@ TEST_CASE("Test Associate in StateDescriptor", "[StateDescriptor]") { } } +TEST_CASE("Test GetPackDimension in StateDescriptor", "[StateDescriptor]") { + GIVEN("Some flags and state descriptors") { + StateDescriptor state("state"); + WHEN("We add some fields with various shapes and total size") { + state.AddField("foo", Metadata(std::vector{}, std::vector{4})); + state.AddField("bar", + Metadata(std::vector{}, std::vector{4, 4})); + state.AddField("baz", + Metadata(std::vector{}, std::vector{4, 4, 4})); + THEN("The total length is identified correctly") { + REQUIRE(state.GetPackDimension(Metadata::GetUserFlag("state")) == 84); + } + } + } +} + TEST_CASE("Test dependency resolution in StateDescriptor", "[StateDescriptor]") { GIVEN("Some empty state descriptors and metadata") { // metadata