From a5e4322bd14e98852b76c50c263b2a4d0bcf95b3 Mon Sep 17 00:00:00 2001 From: Kor de Jong Date: Tue, 21 Jan 2025 15:30:06 +0100 Subject: [PATCH] Add reclassify overloads --- source/framework/algorithm/CMakeLists.txt | 2 +- .../algorithm/default_policies/reclassify.hpp | 7 +- .../algorithm/definition/reclassify.hpp | 16 ++-- .../lue/framework/algorithm/reclassify.hpp | 9 +- .../algorithm/value_policies/reclassify.hpp | 7 +- .../algorithm/global_operation/reclassify.cpp | 85 +++++++++---------- .../global_operation/reclassify_test.py | 10 +-- 7 files changed, 66 insertions(+), 70 deletions(-) diff --git a/source/framework/algorithm/CMakeLists.txt b/source/framework/algorithm/CMakeLists.txt index cb4625be5..ada6b356f 100644 --- a/source/framework/algorithm/CMakeLists.txt +++ b/source/framework/algorithm/CMakeLists.txt @@ -1384,7 +1384,7 @@ block() foreach(Policies IN LISTS LUE_FRAMEWORK_ALGORITHM_POLICIES) foreach(InputElement IN LISTS LUE_FRAMEWORK_INTEGRAL_ELEMENTS) - foreach(OutputElement IN LISTS LUE_FRAMEWORK_FLOATING_POINT_ELEMENTS) + foreach(OutputElement IN LISTS LUE_FRAMEWORK_ELEMENTS) foreach(rank IN LISTS LUE_FRAMEWORK_RANKS) math(EXPR count "${count} + 1") diff --git a/source/framework/algorithm/include/lue/framework/algorithm/default_policies/reclassify.hpp b/source/framework/algorithm/include/lue/framework/algorithm/default_policies/reclassify.hpp index 38efb6dd4..ab7ceb9f1 100644 --- a/source/framework/algorithm/include/lue/framework/algorithm/default_policies/reclassify.hpp +++ b/source/framework/algorithm/include/lue/framework/algorithm/default_policies/reclassify.hpp @@ -17,9 +17,10 @@ namespace lue { namespace default_policies { template - PartitionedArray reclassify( + auto reclassify( PartitionedArray const& array, hpx::shared_future> const& lookup_table) + -> PartitionedArray { using Policies = policy::reclassify::DefaultPolicies; @@ -28,9 +29,9 @@ namespace lue { template - PartitionedArray reclassify( + auto reclassify( PartitionedArray const& array, - LookupTable const& lookup_table) + LookupTable const& lookup_table) -> PartitionedArray { using Policies = policy::reclassify::DefaultPolicies; diff --git a/source/framework/algorithm/include/lue/framework/algorithm/definition/reclassify.hpp b/source/framework/algorithm/include/lue/framework/algorithm/definition/reclassify.hpp index 0f68165da..d27c7a481 100644 --- a/source/framework/algorithm/include/lue/framework/algorithm/definition/reclassify.hpp +++ b/source/framework/algorithm/include/lue/framework/algorithm/definition/reclassify.hpp @@ -8,10 +8,11 @@ namespace lue { namespace detail { template - OutputPartition reclassify_partition_ready( + auto reclassify_partition_ready( Policies const& policies, InputPartition const& input_partition, LookupTable, ElementT> const& lookup_table) + -> OutputPartition { using Offset = OffsetT; using InputData = DataT; @@ -68,11 +69,11 @@ namespace lue { template - OutputPartition reclassify_partition( + auto reclassify_partition( Policies const& policies, InputPartition const& input_partition, hpx::shared_future, ElementT>> const& - lookup_table) + lookup_table) -> OutputPartition { using FromElement = ElementT; using ToElement = ElementT; @@ -105,10 +106,11 @@ namespace lue { template - PartitionedArray reclassify( + auto reclassify( Policies const& policies, PartitionedArray const& input_array, hpx::shared_future> const& lookup_table) + -> PartitionedArray { // Spawn a task for each partition that will reclassify it @@ -128,10 +130,10 @@ namespace lue { InputPartitions const& input_partitions{input_array.partitions()}; OutputPartitions output_partitions{shape_in_partitions(input_array)}; - for (Index p = 0; p < nr_partitions(input_array); ++p) + for (Index partition_idx = 0; partition_idx < nr_partitions(input_array); ++partition_idx) { - output_partitions[p] = - hpx::async(action, localities[p], policies, input_partitions[p], lookup_table); + output_partitions[partition_idx] = hpx::async( + action, localities[partition_idx], policies, input_partitions[partition_idx], lookup_table); } return OutputArray{shape(input_array), localities, std::move(output_partitions)}; diff --git a/source/framework/algorithm/include/lue/framework/algorithm/reclassify.hpp b/source/framework/algorithm/include/lue/framework/algorithm/reclassify.hpp index d6c1edd6a..4f968c7f6 100644 --- a/source/framework/algorithm/include/lue/framework/algorithm/reclassify.hpp +++ b/source/framework/algorithm/include/lue/framework/algorithm/reclassify.hpp @@ -11,17 +11,18 @@ namespace lue { template - PartitionedArray reclassify( + auto reclassify( Policies const& policies, PartitionedArray const& array, - hpx::shared_future> const& lookup_table); + hpx::shared_future> const& lookup_table) + -> PartitionedArray; template - PartitionedArray reclassify( + auto reclassify( Policies const& policies, PartitionedArray const& array, - LookupTable const& lookup_table) + LookupTable const& lookup_table) -> PartitionedArray { return reclassify(policies, array, hpx::make_ready_future(lookup_table).share()); } diff --git a/source/framework/algorithm/include/lue/framework/algorithm/value_policies/reclassify.hpp b/source/framework/algorithm/include/lue/framework/algorithm/value_policies/reclassify.hpp index 3ce385137..824fdc77f 100644 --- a/source/framework/algorithm/include/lue/framework/algorithm/value_policies/reclassify.hpp +++ b/source/framework/algorithm/include/lue/framework/algorithm/value_policies/reclassify.hpp @@ -17,9 +17,10 @@ namespace lue { namespace value_policies { template - PartitionedArray reclassify( + auto reclassify( PartitionedArray const& array, hpx::shared_future> const& lookup_table) + -> PartitionedArray { using Policies = policy::reclassify::DefaultValuePolicies; @@ -28,9 +29,9 @@ namespace lue { template - PartitionedArray reclassify( + auto reclassify( PartitionedArray const& array, - LookupTable const& lookup_table) + LookupTable const& lookup_table) -> PartitionedArray { using Policies = policy::reclassify::DefaultValuePolicies; diff --git a/source/framework/python/source/algorithm/global_operation/reclassify.cpp b/source/framework/python/source/algorithm/global_operation/reclassify.cpp index 33422c8ee..a3921a809 100644 --- a/source/framework/python/source/algorithm/global_operation/reclassify.cpp +++ b/source/framework/python/source/algorithm/global_operation/reclassify.cpp @@ -11,34 +11,21 @@ using namespace pybind11::literals; namespace lue::framework { namespace { - template - auto reclassify2( - PartitionedArray const& array, - LookupTable const& lookup_table) -> PartitionedArray + template + auto cast_lut(LookupTable const& lookup_table) + -> LookupTable { - return value_policies::reclassify(array, lookup_table); - } - - - template - auto cast_lut(LookupTable const& lookup_table) - -> LookupTable - { - static_assert(std::is_integral_v); - static_assert(std::is_floating_point_v); - static_assert(std::is_floating_point_v); - - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) { return lookup_table; } else { - LookupTable result; + LookupTable result; for (auto const& [key, value] : lookup_table) { - result[key] = static_cast(value); + result[key] = static_cast(value); } return result; @@ -46,10 +33,26 @@ namespace lue::framework { } - template + template auto reclassify( PartitionedArray const& array, - LookupTable const& lookup_table, + LookupTable const& lookup_table) -> pybind11::object + { + pybind11::object result{}; + + if constexpr (arithmetic_element_supported) + { + result = pybind11::cast(value_policies::reclassify(array, cast_lut(lookup_table))); + } + + return result; + } + + + template + auto reclassify( + PartitionedArray const& array, + LookupTable const& lookup_table, pybind11::dtype const& dtype) -> pybind11::object { // Switch on dtype and call a function that returns an array of the @@ -65,12 +68,19 @@ namespace lue::framework { // Signed integer switch (size) { + case 1: + { + result = reclassify(array, lookup_table); + break; + } case 4: { + result = reclassify(array, lookup_table); break; } case 8: { + result = reclassify(array, lookup_table); break; } } @@ -84,14 +94,17 @@ namespace lue::framework { { case 1: { + result = reclassify(array, lookup_table); break; } case 4: { + result = reclassify(array, lookup_table); break; } case 8: { + result = reclassify(array, lookup_table); break; } } @@ -105,24 +118,12 @@ namespace lue::framework { { case 4: { - using Element = float; - - if constexpr (arithmetic_element_supported) - { - result = pybind11::cast(reclassify2(array, cast_lut(lookup_table))); - } - + result = reclassify(array, lookup_table); break; } case 8: { - using Element = double; - - if constexpr (arithmetic_element_supported) - { - result = pybind11::cast(reclassify2(array, cast_lut(lookup_table))); - } - + result = reclassify(array, lookup_table); break; } } @@ -140,16 +141,6 @@ namespace lue::framework { } - // template - // auto reclassify( - // PartitionedArray const& array, - // LookupTable const& lookup_table, - // pybind11::object const& dtype_args) -> pybind11::object - // { - // return reclassify1(array, lookup_table, pybind11::dtype::from_args(dtype_args)); - // } - - class Binder { @@ -160,12 +151,12 @@ namespace lue::framework { { Rank const rank{2}; using FromElement = Element; - using ToElement = LargestFloatingPointElement; + using LUTElement = LargestFloatingPointElement; module.def( "reclassify", [](PartitionedArray const& array, - LookupTable const& lookup_table, + LookupTable const& lookup_table, pybind11::object const& dtype_args) // -> pybind11::object { return reclassify(array, lookup_table, pybind11::dtype::from_args(dtype_args)); }, "array"_a, diff --git a/source/framework/python/test/algorithm/global_operation/reclassify_test.py b/source/framework/python/test/algorithm/global_operation/reclassify_test.py index a5eebca66..1e2d22e77 100644 --- a/source/framework/python/test/algorithm/global_operation/reclassify_test.py +++ b/source/framework/python/test/algorithm/global_operation/reclassify_test.py @@ -14,12 +14,12 @@ def test_overloads(self): for from_element_type in lfr.integral_element_types: ids = lfr.create_array(array_shape, from_element_type, id_) - for to_element_type in lfr.floating_point_element_types: + for to_element_type in lfr.arithmetic_element_types: lookup_table = { - 1: 1.1, - 2: 2.2, - 3: 3.3, - 4: 4.4, + 1: 4, + 2: 3, + 3: 2, + 4: 1, } array = lfr.reclassify(ids, lookup_table, dtype=to_element_type)