Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add overloads for reclassify that result in integral arrays #786

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion source/framework/algorithm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,7 @@ block()

foreach(Policies IN LISTS LUE_FRAMEWORK_ALGORITHM_POLICIES)
foreach(InputElement IN LISTS LUE_FRAMEWORK_INTEGRAL_ELEMENTS)
foreach(OutputElement IN LISTS LUE_FRAMEWORK_FLOATING_POINT_ELEMENTS)
foreach(OutputElement IN LISTS LUE_FRAMEWORK_ELEMENTS)
foreach(rank IN LISTS LUE_FRAMEWORK_RANKS)
math(EXPR count "${count} + 1")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ namespace lue {
namespace default_policies {

template<typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
PartitionedArray<FromElement, rank> const& array,
hpx::shared_future<LookupTable<FromElement, ToElement>> const& lookup_table)
-> PartitionedArray<ToElement, rank>
{
using Policies = policy::reclassify::DefaultPolicies<FromElement, ToElement>;

Expand All @@ -28,9 +29,9 @@ namespace lue {


template<typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, ToElement> const& lookup_table)
LookupTable<FromElement, ToElement> const& lookup_table) -> PartitionedArray<ToElement, rank>
{
using Policies = policy::reclassify::DefaultPolicies<FromElement, ToElement>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ namespace lue {
namespace detail {

template<typename OutputPartition, typename Policies, typename InputPartition>
OutputPartition reclassify_partition_ready(
auto reclassify_partition_ready(
Policies const& policies,
InputPartition const& input_partition,
LookupTable<ElementT<InputPartition>, ElementT<OutputPartition>> const& lookup_table)
-> OutputPartition
{
using Offset = OffsetT<InputPartition>;
using InputData = DataT<InputPartition>;
Expand Down Expand Up @@ -68,11 +69,11 @@ namespace lue {


template<typename Policies, typename InputPartition, typename OutputPartition>
OutputPartition reclassify_partition(
auto reclassify_partition(
Policies const& policies,
InputPartition const& input_partition,
hpx::shared_future<LookupTable<ElementT<InputPartition>, ElementT<OutputPartition>>> const&
lookup_table)
lookup_table) -> OutputPartition
{
using FromElement = ElementT<InputPartition>;
using ToElement = ElementT<OutputPartition>;
Expand Down Expand Up @@ -105,10 +106,11 @@ namespace lue {


template<typename Policies, typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
Policies const& policies,
PartitionedArray<FromElement, rank> const& input_array,
hpx::shared_future<LookupTable<FromElement, ToElement>> const& lookup_table)
-> PartitionedArray<ToElement, rank>
{
// Spawn a task for each partition that will reclassify it

Expand All @@ -128,10 +130,10 @@ namespace lue {
InputPartitions const& input_partitions{input_array.partitions()};
OutputPartitions output_partitions{shape_in_partitions(input_array)};

for (Index p = 0; p < nr_partitions(input_array); ++p)
for (Index partition_idx = 0; partition_idx < nr_partitions(input_array); ++partition_idx)
{
output_partitions[p] =
hpx::async(action, localities[p], policies, input_partitions[p], lookup_table);
output_partitions[partition_idx] = hpx::async(
action, localities[partition_idx], policies, input_partitions[partition_idx], lookup_table);
}

return OutputArray{shape(input_array), localities, std::move(output_partitions)};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@ namespace lue {


template<typename Policies, typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
Policies const& policies,
PartitionedArray<FromElement, rank> const& array,
hpx::shared_future<LookupTable<FromElement, ToElement>> const& lookup_table);
hpx::shared_future<LookupTable<FromElement, ToElement>> const& lookup_table)
-> PartitionedArray<ToElement, rank>;


template<typename Policies, typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
Policies const& policies,
PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, ToElement> const& lookup_table)
LookupTable<FromElement, ToElement> const& lookup_table) -> PartitionedArray<ToElement, rank>
{
return reclassify(policies, array, hpx::make_ready_future(lookup_table).share());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ namespace lue {
namespace value_policies {

template<typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
PartitionedArray<FromElement, rank> const& array,
hpx::shared_future<LookupTable<FromElement, ToElement>> const& lookup_table)
-> PartitionedArray<ToElement, rank>
{
using Policies = policy::reclassify::DefaultValuePolicies<FromElement, ToElement>;

Expand All @@ -28,9 +29,9 @@ namespace lue {


template<typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, ToElement> const& lookup_table)
LookupTable<FromElement, ToElement> const& lookup_table) -> PartitionedArray<ToElement, rank>
{
using Policies = policy::reclassify::DefaultValuePolicies<FromElement, ToElement>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,45 +11,48 @@ using namespace pybind11::literals;
namespace lue::framework {
namespace {

template<typename FromElement, typename ToElement, Rank rank>
auto reclassify2(
PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, ToElement> const& lookup_table) -> PartitionedArray<ToElement, rank>
template<typename ToElement, std::integral FromElement, std::floating_point LUTElement>
auto cast_lut(LookupTable<FromElement, LUTElement> const& lookup_table)
-> LookupTable<FromElement, ToElement>
{
return value_policies::reclassify(array, lookup_table);
}


template<typename ToElement2, typename FromElement, typename ToElement1>
auto cast_lut(LookupTable<FromElement, ToElement1> const& lookup_table)
-> LookupTable<FromElement, ToElement2>
{
static_assert(std::is_integral_v<FromElement>);
static_assert(std::is_floating_point_v<ToElement1>);
static_assert(std::is_floating_point_v<ToElement2>);

if constexpr (std::is_same_v<ToElement1, ToElement2>)
if constexpr (std::is_same_v<LUTElement, ToElement>)
{
return lookup_table;
}
else
{
LookupTable<FromElement, ToElement2> result;
LookupTable<FromElement, ToElement> result;

for (auto const& [key, value] : lookup_table)
{
result[key] = static_cast<ToElement2>(value);
result[key] = static_cast<ToElement>(value);
}

return result;
}
}


template<typename FromElement, typename ToElement, Rank rank>
template<Arithmetic ToElement, std::integral FromElement, std::floating_point LUTElement, Rank rank>
auto reclassify(
PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, ToElement> const& lookup_table,
LookupTable<FromElement, LUTElement> const& lookup_table) -> pybind11::object
{
pybind11::object result{};

if constexpr (arithmetic_element_supported<ToElement>)
{
result = pybind11::cast(value_policies::reclassify(array, cast_lut<ToElement>(lookup_table)));
}

return result;
}


template<std::integral FromElement, std::floating_point LUTElement, Rank rank>
auto reclassify(
PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, LUTElement> const& lookup_table,
pybind11::dtype const& dtype) -> pybind11::object
{
// Switch on dtype and call a function that returns an array of the
Expand All @@ -65,12 +68,19 @@ namespace lue::framework {
// Signed integer
switch (size)
{
case 1:
{
result = reclassify<std::int8_t>(array, lookup_table);
break;
}
case 4:
{
result = reclassify<std::int32_t>(array, lookup_table);
break;
}
case 8:
{
result = reclassify<std::int64_t>(array, lookup_table);
break;
}
}
Expand All @@ -84,14 +94,17 @@ namespace lue::framework {
{
case 1:
{
result = reclassify<std::uint8_t>(array, lookup_table);
break;
}
case 4:
{
result = reclassify<std::uint32_t>(array, lookup_table);
break;
}
case 8:
{
result = reclassify<std::uint64_t>(array, lookup_table);
break;
}
}
Expand All @@ -105,24 +118,12 @@ namespace lue::framework {
{
case 4:
{
using Element = float;

if constexpr (arithmetic_element_supported<Element>)
{
result = pybind11::cast(reclassify2(array, cast_lut<Element>(lookup_table)));
}

result = reclassify<float>(array, lookup_table);
break;
}
case 8:
{
using Element = double;

if constexpr (arithmetic_element_supported<Element>)
{
result = pybind11::cast(reclassify2(array, cast_lut<Element>(lookup_table)));
}

result = reclassify<double>(array, lookup_table);
break;
}
}
Expand All @@ -140,16 +141,6 @@ namespace lue::framework {
}


// template<typename FromElement, typename ToElement, Rank rank>
// auto reclassify(
// PartitionedArray<FromElement, rank> const& array,
// LookupTable<FromElement, ToElement> const& lookup_table,
// pybind11::object const& dtype_args) -> pybind11::object
// {
// return reclassify1(array, lookup_table, pybind11::dtype::from_args(dtype_args));
// }


class Binder
{

Expand All @@ -160,12 +151,12 @@ namespace lue::framework {
{
Rank const rank{2};
using FromElement = Element;
using ToElement = LargestFloatingPointElement;
using LUTElement = LargestFloatingPointElement;

module.def(
"reclassify",
[](PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, ToElement> const& lookup_table,
LookupTable<FromElement, LUTElement> const& lookup_table,
pybind11::object const& dtype_args) // -> pybind11::object
{ return reclassify(array, lookup_table, pybind11::dtype::from_args(dtype_args)); },
"array"_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ def test_overloads(self):
for from_element_type in lfr.integral_element_types:
ids = lfr.create_array(array_shape, from_element_type, id_)

for to_element_type in lfr.floating_point_element_types:
for to_element_type in lfr.arithmetic_element_types:
lookup_table = {
1: 1.1,
2: 2.2,
3: 3.3,
4: 4.4,
1: 4,
2: 3,
3: 2,
4: 1,
}
array = lfr.reclassify(ids, lookup_table, dtype=to_element_type)

Expand Down
Loading