diff --git a/cpp/include/legate_dataframe/core/library.hpp b/cpp/include/legate_dataframe/core/library.hpp
index fc90d00..d6e9c0b 100644
--- a/cpp/include/legate_dataframe/core/library.hpp
+++ b/cpp/include/legate_dataframe/core/library.hpp
@@ -37,6 +37,7 @@ enum : int {
   ToTimestamps,
   ExtractTimestampComponent,
   Sequence,
+  Sort,
   GroupByAggregation
 };
 }
diff --git a/cpp/include/legate_dataframe/core/repartition_by_hash.hpp b/cpp/include/legate_dataframe/core/repartition_by_hash.hpp
index 104d9c4..d13c287 100644
--- a/cpp/include/legate_dataframe/core/repartition_by_hash.hpp
+++ b/cpp/include/legate_dataframe/core/repartition_by_hash.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,12 @@
 namespace legate::dataframe::task {
 
+std::pair<std::vector<cudf::table_view>,
+          std::unique_ptr<std::pair<std::map<int, rmm::device_buffer>, cudf::table>>>
+shuffle(GPUTaskContext& ctx,
+        std::vector<cudf::table_view>& tbl_partitioned,
+        std::unique_ptr<cudf::table> owning_table);
+
 /**
  * @brief Repartition the table into hash table buckets for each rank/node.
  *
diff --git a/cpp/include/legate_dataframe/sort.hpp b/cpp/include/legate_dataframe/sort.hpp
new file mode 100644
index 0000000..2321c1c
--- /dev/null
+++ b/cpp/include/legate_dataframe/sort.hpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+
+#include
+
+namespace legate::dataframe {
+
+/**
+ * @brief Sort a logical table.
+ *
+ * Reorder the logical table so that the key columns are sorted lexicographically
+ * based on their column_order and null_precedence.
+ *
+ * @param tbl The table to sort.
+ * @param keys The column names to sort by.
+ * @param column_order Either ASCENDING or DESCENDING for each sort key/column.
+ * @param null_precedence Either BEFORE or AFTER for each sort key/column.
+ * AFTER means that nulls are considered larger and come last after an ascending
+ * sort and first after a descending sort.
+ * @return The sorted LogicalTable
+ */
+LogicalTable sort(const LogicalTable& tbl,
+                  const std::vector<std::string>& keys,
+                  const std::vector<cudf::order>& column_order,
+                  const std::vector<cudf::null_order>& null_precedence,
+                  bool stable = false);
+
+}  // namespace legate::dataframe
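For orientation, here is a minimal usage sketch of the header above. It is not part of the patch: the `legate_dataframe/core/table.hpp` include path and the column names are assumptions for illustration only.

#include <cudf/types.hpp>
#include <legate_dataframe/core/table.hpp>  // assumed location of LogicalTable
#include <legate_dataframe/sort.hpp>

// Sort by "a" ascending and "b" descending; with these null orders, nulls end
// up last in both key columns.
legate::dataframe::LogicalTable sorted_by_a_b(const legate::dataframe::LogicalTable& tbl)
{
  return legate::dataframe::sort(tbl,
                                 {"a", "b"},
                                 {cudf::order::ASCENDING, cudf::order::DESCENDING},
                                 {cudf::null_order::AFTER, cudf::null_order::BEFORE},
                                 /* stable = */ false);
}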
diff --git a/cpp/src/core/library.cpp b/cpp/src/core/library.cpp
index 326835a..d63c9a6 100644
--- a/cpp/src/core/library.cpp
+++ b/cpp/src/core/library.cpp
@@ -80,6 +80,10 @@ class Mapper : public legate::mapping::Mapper {
       // TODO: Join is identical to GroupBy, but we would have to look at
       // both input tables to ge the maximum column number.
       return std::nullopt;
+    case legate::dataframe::task::OpCode::Sort:
+      // Also similar to GroupBy, but does multiple shuffles and uses two
+      // additional helper columns.
+      return std::nullopt;
     case legate::dataframe::task::OpCode::GroupByAggregation: {
       // Aggregation use repartitioning which uses ZCMEM for NCCL.
       // This depends on the number of columns (first scalar when storing
@@ -88,6 +92,7 @@ class Mapper : public legate::mapping::Mapper {
         // No need for repartitioning, so no need for ZCMEM
         return 0;
       }
+      // Note: + 2 is for sorting!  TODO: refactor into helper...
       auto num_cols = task.scalars().at(0).value();
       auto nrank    = task.get_launch_domain().get_volume();
 
diff --git a/cpp/src/core/repartition_by_hash.cu b/cpp/src/core/repartition_by_hash.cu
index 9335841..b49f8e5 100644
--- a/cpp/src/core/repartition_by_hash.cu
+++ b/cpp/src/core/repartition_by_hash.cu
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -118,19 +118,49 @@ class ExchangedSizes {
   }
 };
 
+}  // namespace
+
 /**
- * @brief Shuffle (all-to-all exchange) packed cudf columns.
+ * @brief Shuffle (all-to-all exchange) a partitioned cudf table.
  *
  *
  * @param ctx The context of the calling task
- * @param columns A mapping of tasks to their packed columns. E.g. `columns.at(i)`
- * will be send to the i'th task. NB: all tasks beside itself must have a map thus:
- * `columns.size() == ctx.nranks - 1`.
- * @return A new table containing "this nodes" unpacked columns.
+ * @param tbl_partitioned The local table partitioned into multiple tables such
+ * that `tbl_partitioned.at(i)` should end up at rank i.
+ * @param owning_table Optional table owning the data in `tbl_partitioned`.
+ * This table is cleaned up early to reduce the peak memory usage.
+ * If passed, `tbl_partitioned` is also cleared (as its content is invalid).
+ * @return A std::pair whose first entry is a vector of table_view with all the
+ * chunks (including the local copy). The second entry is a unique_ptr whose
+ * contents own all parts.
 */
-std::pair<std::vector<cudf::table_view>, std::map<int, rmm::device_buffer>> shuffle(
-  GPUTaskContext& ctx, const std::map<int, cudf::packed_columns>& columns)
+std::pair<std::vector<cudf::table_view>,
+          std::unique_ptr<std::pair<std::map<int, rmm::device_buffer>, cudf::table>>>
+shuffle(GPUTaskContext& ctx,
+        std::vector<cudf::table_view>& tbl_partitioned,
+        std::unique_ptr<cudf::table> owning_table)
 {
+  if (tbl_partitioned.size() != ctx.nranks) {
+    throw std::runtime_error("internal error: partition split has wrong size.");
+  }
+
+  // First we pack the columns into contiguous chunks for the transfer/shuffle.
+  // `columns.at(i)` will be sent to the i'th task.
+  // N.B. every task besides itself has an entry, so `columns.size() == ctx.nranks - 1`.
+  std::map<int, cudf::packed_columns> columns;
+  for (int i = 0; static_cast<std::size_t>(i) < tbl_partitioned.size(); ++i) {
+    if (i != ctx.rank) {
+      columns[i] = cudf::detail::pack(tbl_partitioned[i], ctx.stream(), ctx.mr());
+    }
+  }
+  // Also copy tbl_partitioned.at(ctx.rank). This copy is unnecessary but allows
+  // clearing the (possibly) much larger owning_table (if passed).
+  cudf::table local_table(tbl_partitioned.at(ctx.rank), ctx.stream(), ctx.mr());
+  if (owning_table) {
+    tbl_partitioned.clear();
+    owning_table.reset();
+  }
+
   assert(columns.size() == ctx.nranks - 1);
 
   ExchangedSizes sizes(ctx, columns);
@@ -200,18 +230,25 @@ std::pair<std::vector<cudf::table_view>, std::map<int, rmm::device_buffer>> shuf
   LEGATE_CHECK_CUDA(cudaStreamSynchronize(sizes.stream));
 
   // Let's unpack and return the packed_columns received from our peers
+  // (and our own chunk so that `ret` is ordered for stable sorts)
   std::vector<cudf::table_view> ret;
-  for (auto& [peer, buf] : recv_metadata) {
+  for (int peer = 0; peer < ctx.nranks; ++peer) {
+    if (peer == ctx.rank) {
+      ret.push_back(local_table.view());
+      continue;
+    }
     uint8_t* gpu_data = nullptr;
     if (recv_gpu_data.count(peer)) {
      gpu_data = static_cast<uint8_t*>(recv_gpu_data.at(peer).data());
     }
-    ret.push_back(cudf::unpack(buf.ptr(0), gpu_data));
+    ret.push_back(cudf::unpack(recv_metadata.at(peer).ptr(0), gpu_data));
   }
-  return std::make_pair(ret, std::move(recv_gpu_data));
-}
-}  // namespace
 
+  using owner_t = std::pair<std::map<int, rmm::device_buffer>, cudf::table>;
+  return std::make_pair(
+    ret,
+    std::make_unique<owner_t>(std::make_pair(std::move(recv_gpu_data), std::move(local_table))));
+}
 
 std::unique_ptr<cudf::table> repartition_by_hash(
   GPUTaskContext& ctx,
@@ -223,8 +260,8 @@ std::unique_ptr<cudf::table> repartition_by_hash(
  * 1) Each task split their local cudf table into `ctx.nranks` partitions based on the
  *    hashing of `columns_to_hash` and assign each partition to a task.
  * 2) Each task pack (serialize) the partitions not assigned to itself.
- * 3) All tasks exchange the sizes of their packed partitions and associated metadata.
- * 4) All tasks shuffle (all-to-all exchange) the packed partitions.
+ * 4) All tasks shuffle (all-to-all exchange) the partitions. `shuffle` does this by first
+ *    packing each partition into a contiguous memory block for the transfer.
  * 5) Each task unpack (deserialize) and concatenate the received columns with the self-assigned
  *    partition.
  * 6) Finally, each task return a new local cudf table that contains the concatenated partitions.
@@ -261,28 +298,9 @@ std::unique_ptr<cudf::table> repartition_by_hash(
     tbl_partitioned = cudf::split(*partition_table, partition_offsets, ctx.stream());
   }
 
-  if (tbl_partitioned.size() != ctx.nranks) {
-    throw std::runtime_error("internal error: partition split has wrong size.");
-  }
-
-  // Pack and shuffle the columns
-  std::map<int, cudf::packed_columns> packed_columns;
-  for (int i = 0; static_cast<std::size_t>(i) < tbl_partitioned.size(); ++i) {
-    if (i != ctx.rank) {
-      packed_columns[i] = cudf::detail::pack(tbl_partitioned[i], ctx.stream(), ctx.mr());
-    }
-  }
-  // Also copy tbl_partitioned.at(ctx.rank). This copy is unnecessary but allows
-  // clearing the (presumably) much larger partition_table.
-  cudf::table local_table(tbl_partitioned.at(ctx.rank), ctx.stream(), ctx.mr());
-  tbl_partitioned.clear();
-  partition_table.reset();
-  auto [tables, buffers] = shuffle(ctx, packed_columns);
-  packed_columns.clear();  // Clear packed columns to preserve memory
+  auto [tables, owners] = shuffle(ctx, tbl_partitioned, std::move(partition_table));
 
-  // Let's concatenate our own partition and all the partitioned received from the shuffle.
-  tables.push_back(local_table);
   return cudf::concatenate(tables, ctx.stream(), ctx.mr());
 }
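The reworked `shuffle()` returns views that alias memory owned by the second element of the returned pair, so a caller must keep that owner alive until it has materialized a result. A condensed sketch of the intended call pattern, mirroring `repartition_by_hash` above (the helper name, offsets parameter, and the templated context type are assumptions, not part of the patch):

#include <cudf/concatenate.hpp>
#include <cudf/copying.hpp>
#include <legate_dataframe/core/repartition_by_hash.hpp>

template <typename GPUTaskContextT>
std::unique_ptr<cudf::table> shuffle_and_concatenate(GPUTaskContextT& ctx,
                                                     std::unique_ptr<cudf::table> local_table,
                                                     const std::vector<cudf::size_type>& offsets)
{
  // Split the locally prepared table into one chunk per rank.
  std::vector<cudf::table_view> parts = cudf::split(local_table->view(), offsets, ctx.stream());
  // Handing over ownership lets shuffle() free the big table once chunks are packed.
  auto [views, owners] = legate::dataframe::task::shuffle(ctx, parts, std::move(local_table));
  // `views` alias memory held by `owners`; materialize a result before `owners` is dropped.
  return cudf::concatenate(views, ctx.stream(), ctx.mr());
}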
diff --git a/cpp/src/sort.cpp b/cpp/src/sort.cpp
new file mode 100644
index 0000000..db15ea8
--- /dev/null
+++ b/cpp/src/sort.cpp
@@ -0,0 +1,406 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#define DEBUG_SPLITS 0
+#if DEBUG_SPLITS
+#include
+#include
+#endif
+
+namespace legate::dataframe {
+namespace task {
+namespace {
+
+/**
+ * @brief Return the points at which to split a dataset.
+ *
+ * @param nvalues The total number of values to split.
+ * @param nsplits The number of splits (also the number of split values when the
+ * start is included).
+ * @param include_start Whether to include the starting 0.
+ * @returns A cudf column containing the split indices.
+ */
+std::unique_ptr<cudf::column> get_split_ind(GPUTaskContext& ctx,
+                                            cudf::size_type nvalues,
+                                            int nsplits,
+                                            bool include_start)
+{
+  auto nvalues_per_split = nvalues / nsplits;
+  auto nvalues_left      = nvalues - nvalues_per_split * nsplits;
+  if (nvalues_per_split == 0) {
+    nsplits = nvalues_left;  // Only return non-empty splits
+  }
+
+  std::vector<cudf::size_type> split_values;
+  cudf::size_type split_offset = 0;
+
+  if (include_start) { split_values.push_back(0); }
+
+  for (cudf::size_type i = 0; i < nsplits - 1; i++) {
+    split_offset += nvalues_per_split;
+    if (i < nvalues_left) { split_offset += 1; }
+
+    split_values.push_back(split_offset);
+  }
+  assert(nvalues_per_split == 0 || split_offset + nvalues_per_split == nvalues);
+
+#if DEBUG_SPLITS
+  std::ostringstream splits_points_oss;
+  splits_points_oss << "Split points @" << nvalues << ": ";
+  for (auto point : split_values) {
+    splits_points_oss << point << ", ";
+  }
+  std::cout << splits_points_oss.str() << std::endl;
+#endif
+
+  auto ncopy = split_values.size();
+  rmm::device_uvector<cudf::size_type> split_ind(ncopy, ctx.stream(), ctx.mr());
+  LEGATE_CHECK_CUDA(cudaMemcpyAsync(split_ind.data(),
+                                    split_values.data(),
+                                    ncopy * sizeof(cudf::size_type),
+                                    cudaMemcpyHostToDevice,
+                                    ctx.stream()));
+  LEGATE_CHECK_CUDA(cudaStreamSynchronize(ctx.stream()));
+
+  return std::make_unique<cudf::column>(std::move(split_ind), rmm::device_buffer{}, 0);
+}
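To make the offset arithmetic in get_split_ind concrete, here is a small host-only analogue (illustrative, not part of the patch): with nvalues = 10 and nsplits = 4 the two leftover rows go to the first two chunks.

#include <vector>

// Mirrors the offset computation above on the host (illustrative only).
std::vector<int> split_offsets(int nvalues, int nsplits, bool include_start)
{
  int per_split = nvalues / nsplits;              // 10 / 4 -> 2
  int left      = nvalues - per_split * nsplits;  // 10 - 8 -> 2 rows left over
  if (per_split == 0) { nsplits = left; }         // only non-empty splits

  std::vector<int> out;
  if (include_start) { out.push_back(0); }
  int offset = 0;
  for (int i = 0; i < nsplits - 1; i++) {
    offset += per_split + (i < left ? 1 : 0);
    out.push_back(offset);
  }
  return out;  // nvalues=10, nsplits=4, include_start=false -> {3, 6, 8}
}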
+
+/*
+ * The practical way to do a distributed sort is to use the initial locally
+ * sorted table to estimate good split points to shuffle data to the final node.
+ *
+ * The rough approach for shuffling the data is the following:
+ * 1. Extract `nranks` split candidates from the local table and add their rank
+ *    and local index.
+ * 2. Exchange all split candidate values and sort them.
+ * 3. Again extract those candidates that evenly split the whole candidate set
+ *    (we do this on all nodes).
+ * 4. Shuffle the data based on the final split candidates.
+ *
+ * This approach is e.g. the same as in cupynumeric. We cannot guarantee balanced
+ * result chunk sizes, but it should ensure results are within 2x the input chunk
+ * size. If all chunks are balanced and have the same distribution, the result
+ * will be (approximately) balanced again.
+ *
+ * The trickiest thing to take care of is equal values. Depending on which rank
+ * the split point came from (i.e. where it sits globally relative to us), we need
+ * to pick the split point index (if it is ours), the first equal value (if it came
+ * from an earlier rank), or the position just after the last equal value (if it
+ * came from a later rank).
+ */
+std::unique_ptr<std::vector<cudf::size_type>> find_splits_for_distribution(
+  GPUTaskContext& ctx,
+  const cudf::table_view& my_sorted_tbl,
+  const std::vector<cudf::size_type>& keys_idx,
+  const std::vector<cudf::order>& column_order,
+  const std::vector<cudf::null_order>& null_precedence)
+{
+  /*
+   * Step 1: Extract local candidates and add rank and index information.
+   *
+   * We use the start index to find the value representing the range
+   * (used as a possible split value), but store the corresponding end of the
+   * previous range.
+   */
+  auto my_split_ind_col =
+    get_split_ind(ctx, my_sorted_tbl.num_rows(), ctx.nranks, /* include_start */ true);
+  auto nsplits = my_split_ind_col->size();
+
+  auto my_split_rank_col = cudf::sequence(nsplits,
+                                          *cudf::make_fixed_width_scalar(int32_t{ctx.rank}),
+                                          *cudf::make_fixed_width_scalar(int32_t{0}));
+
+  auto my_split_cols_tbl = cudf::gather(my_sorted_tbl.select(keys_idx),
+                                        my_split_ind_col->view(),
+                                        cudf::out_of_bounds_policy::DONT_CHECK,
+                                        ctx.stream(),
+                                        ctx.mr());
+
+  auto my_split_cols_view = my_split_cols_tbl->view();
+  auto my_split_cols_vector =
+    std::vector<cudf::column_view>(my_split_cols_view.begin(), my_split_cols_view.end());
+
+  // Add in rank and local index (together they provide a global order).
+  my_split_cols_vector.push_back(my_split_rank_col->view());
+  my_split_cols_vector.push_back(my_split_ind_col->view());
+  auto my_splits = cudf::table_view(my_split_cols_vector);
+
+  // keys(x) to pick columns from the splits (which include rank and index):
+  std::vector<cudf::size_type> value_keysx(keys_idx.size());
+  std::iota(value_keysx.begin(), value_keysx.end(), 0);
+  std::vector<cudf::size_type> all_keysx(keys_idx.size() + 2);
+  std::iota(all_keysx.begin(), all_keysx.end(), 0);
+
+  /*
+   * Step 2: Share split candidates among all ranks.
+   */
+  std::vector<cudf::table_view> exchange_tables;
+  for (int i = 0; i < ctx.nranks; i++) {
+    exchange_tables.push_back(my_splits);
+  }
+  auto [split_candidates_shared, owners_split] = shuffle(ctx, exchange_tables, nullptr);
+
+  std::vector<cudf::order> column_orderx(column_order);
+  std::vector<cudf::null_order> null_precedencex(null_precedence);
+  column_orderx.insert(column_orderx.end(), {cudf::order::ASCENDING, cudf::order::ASCENDING});
+  null_precedencex.insert(null_precedencex.end(),
+                          {cudf::null_order::AFTER, cudf::null_order::AFTER});
+
+  // Merge is stable as it includes the rank and index in the keys:
+  auto split_candidates = cudf::merge(
+    split_candidates_shared, all_keysx, column_orderx, null_precedencex, ctx.stream(), ctx.mr());
+  owners_split.reset();  // copied into split_candidates
+
+  /*
+   * Step 3: Find the best splitting points from all candidates.
+   */
+  auto split_value_inds =
+    get_split_ind(ctx, split_candidates->num_rows(), ctx.nranks, /* include_start */ false);
+  auto split_values_tbl = cudf::gather(split_candidates->view(),
+                                       split_value_inds->view(),
+                                       cudf::out_of_bounds_policy::DONT_CHECK,
+                                       ctx.stream(),
+                                       ctx.mr());
+  auto split_values_view = split_values_tbl->view();
+
+  /*
+   * Step 4: Find the actual split points for the local dataset.
+   *
+   * We need to split based on the rank of the split point `split_rank`
+   * (i.e. where the split point is located in the whole dataset):
+   * - if split_rank < my_rank: split at the first equal row.
+   * - if split_rank == my_rank: use the split-point index.
+   * - if split_rank > my_rank: split after the last equal row.
+   *
+   * N.B.: If this turns out to matter speed-wise, this can be spelled as a single
+   * `lower_bound` with the (global) row-index. A custom implementation could
+   * make that row-index a virtual table.
+   */
+  auto split_candidates_first_col = cudf::lower_bound(my_sorted_tbl.select(keys_idx),
+                                                      split_values_view.select(value_keysx),
+                                                      column_order,
+                                                      null_precedence,
+                                                      ctx.stream(),
+                                                      ctx.mr());
+  auto split_candidates_first_view = split_candidates_first_col->view();
+  auto split_candidates_last_col   = cudf::upper_bound(my_sorted_tbl.select(keys_idx),
+                                                       split_values_view.select(value_keysx),
+                                                       column_order,
+                                                       null_precedence,
+                                                       ctx.stream(),
+                                                       ctx.mr());
+  auto split_candidates_last_view = split_candidates_last_col->view();
+
+  // The local index and rank of each split value; the index is used if the
+  // split value came from this rank.
+  auto split_candidates_equal_view = split_values_view.column(my_splits.num_columns() - 1);
+  auto split_candidates_rank_view  = split_values_view.column(my_splits.num_columns() - 2);
+
+  /*
+   * Copy all the above information to the host and finalize the local splits.
+   */
+  auto nsplitpoints = ctx.nranks - 1;
+  std::vector<cudf::size_type> split_candidates_first(nsplitpoints);
+  std::vector<cudf::size_type> split_candidates_last(nsplitpoints);
+  std::vector<cudf::size_type> split_candidates_equal(nsplitpoints);
+  std::vector<int32_t> split_candidates_rank(nsplitpoints);
+
+  LEGATE_CHECK_CUDA(cudaMemcpyAsync(split_candidates_first.data(),
+                                    split_candidates_first_view.data<cudf::size_type>(),
+                                    nsplitpoints * sizeof(cudf::size_type),
+                                    cudaMemcpyDeviceToHost,
+                                    ctx.stream()));
+  LEGATE_CHECK_CUDA(cudaMemcpyAsync(split_candidates_last.data(),
+                                    split_candidates_last_view.data<cudf::size_type>(),
+                                    nsplitpoints * sizeof(cudf::size_type),
+                                    cudaMemcpyDeviceToHost,
+                                    ctx.stream()));
+  LEGATE_CHECK_CUDA(cudaMemcpyAsync(split_candidates_equal.data(),
+                                    split_candidates_equal_view.data<cudf::size_type>(),
+                                    nsplitpoints * sizeof(cudf::size_type),
+                                    cudaMemcpyDeviceToHost,
+                                    ctx.stream()));
+  LEGATE_CHECK_CUDA(cudaMemcpyAsync(split_candidates_rank.data(),
+                                    split_candidates_rank_view.data<int32_t>(),
+                                    nsplitpoints * sizeof(int32_t),
+                                    cudaMemcpyDeviceToHost,
+                                    ctx.stream()));
+
+  LEGATE_CHECK_CUDA(cudaStreamSynchronize(ctx.stream()));
+
+  auto splits_host = std::make_unique<std::vector<cudf::size_type>>();
+  for (int i = 0; i < nsplitpoints; i++) {
+    if (split_candidates_rank[i] < ctx.rank) {
+      splits_host->push_back(split_candidates_first[i]);
+    } else if (split_candidates_rank[i] > ctx.rank) {
+      splits_host->push_back(split_candidates_last[i]);
+    } else {
+      splits_host->push_back(split_candidates_equal[i]);
+    }
+  }
+
+#if DEBUG_SPLITS
+  std::ostringstream full_splits_oss;
+  full_splits_oss << "Final local split points @" << ctx.rank
+                  << " (nrows=" << my_sorted_tbl.num_rows() << "):\n";
+  for (int i = 0; i < nsplitpoints; i++) {
+    full_splits_oss << "    " << splits_host->at(i) << ", split by r";
+    full_splits_oss << split_candidates_rank[i] << ": ";
+    full_splits_oss << split_candidates_first[i] << "<" << split_candidates_last[i];
+    full_splits_oss << ", r[ind]=" << split_candidates_equal[i] << "\n";
+  }
+  std::cout << full_splits_oss.str() << std::endl;
+#endif
+  return splits_host;
+}
+
+}  // namespace
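The rank-based tie-breaking above is the subtle part: `lower_bound` yields the first row equal to the split value and `upper_bound` the position just past the last equal row. A plain-C++ analogue of the selection, reduced to a single integer key column (illustrative only; the function and parameter names are not part of the patch):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// For one split value, decide where this rank cuts its locally sorted keys.
// `split_rank`/`split_index` are the rank and local row index stored with the
// candidate; `my_rank` is this task's rank.
std::ptrdiff_t local_cut(const std::vector<int>& sorted_keys,
                         int split_value,
                         int32_t split_rank,
                         std::ptrdiff_t split_index,
                         int32_t my_rank)
{
  auto first = std::lower_bound(sorted_keys.begin(), sorted_keys.end(), split_value);
  auto last  = std::upper_bound(sorted_keys.begin(), sorted_keys.end(), split_value);
  if (split_rank < my_rank) {
    return first - sorted_keys.begin();  // our equal rows sort after the split point
  }
  if (split_rank > my_rank) {
    return last - sorted_keys.begin();  // our equal rows sort before the split point
  }
  return split_index;  // the split value is ours: cut exactly at its stored index
}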
+
+class SortTask : public Task<SortTask, OpCode::Sort> {
+ public:
+  static void gpu_variant(legate::TaskContext context)
+  {
+    GPUTaskContext ctx{context};
+    const auto tbl             = argument::get_next_input<PhysicalTable>(ctx);
+    const auto keys_idx        = argument::get_next_scalar_vector<cudf::size_type>(ctx);
+    const auto column_order    = argument::get_next_scalar_vector<cudf::order>(ctx);
+    const auto null_precedence = argument::get_next_scalar_vector<cudf::null_order>(ctx);
+    const auto stable          = argument::get_next_scalar<bool>(ctx);
+    auto output                = argument::get_next_output<PhysicalTable>(ctx);
+
+    if (tbl.is_broadcasted() && ctx.rank != 0) {
+      // Note: It might be nice to just sort locally and keep it broadcast.
+      output.bind_empty_data();
+      return;
+    }
+
+    // Create a new locally sorted table (we always need this)
+    auto cudf_tbl  = tbl.table_view();
+    auto key       = cudf_tbl.select(keys_idx);
+    auto sort_func = stable ? cudf::stable_sort_by_key : cudf::sort_by_key;
+    auto my_sorted_tbl =
+      sort_func(cudf_tbl, key, column_order, null_precedence, ctx.stream(), ctx.mr());
+
+    if (ctx.nranks == 1 || tbl.is_broadcasted()) {
+      output.move_into(my_sorted_tbl->release());
+      return;
+    }
+
+    auto split_indices = find_splits_for_distribution(
+      ctx, my_sorted_tbl->view(), keys_idx, column_order, null_precedence);
+
+    auto partitions      = cudf::split(my_sorted_tbl->view(), *split_indices, ctx.stream());
+    auto [parts, owners] = shuffle(ctx, partitions, std::move(my_sorted_tbl));
+
+    std::unique_ptr<cudf::table> result;
+    if (!stable) {
+      result = cudf::merge(parts, keys_idx, column_order, null_precedence, ctx.stream(), ctx.mr());
+    } else {
+      // This is not ideal, but libcudf has no stable merge:
+      // https://github.com/rapidsai/cudf/issues/16010
+      // https://github.com/rapidsai/cudf/issues/7379
+      result = cudf::concatenate(parts, ctx.stream(), ctx.mr());
+      owners.reset();  // we created a copy.
+      auto res_view = result->view();
+      result        = sort_func(
+        res_view, res_view.select(keys_idx), column_order, null_precedence, ctx.stream(), ctx.mr());
+    }
+
+#if DEBUG_SPLITS
+    std::ostringstream result_size_oss;
+    result_size_oss << "Rank/chunk " << ctx.rank << " includes " << result->num_rows()
+                    << " rows.\n";
+    result_size_oss << "    from individual chunks: ";
+    for (auto part : parts) {
+      result_size_oss << part.num_rows() << ", ";
+    }
+    std::cout << result_size_oss.str() << std::endl;
+#endif
+    output.move_into(std::move(result));
+  }
+};
+
+}  // namespace task
+
+LogicalTable sort(const LogicalTable& tbl,
+                  const std::vector<std::string>& keys,
+                  const std::vector<cudf::order>& column_order,
+                  const std::vector<cudf::null_order>& null_precedence,
+                  bool stable)
+{
+  if (keys.size() == 0) { throw std::invalid_argument("must sort along at least one column"); }
+  if (column_order.size() != keys.size() || null_precedence.size() != keys.size()) {
+    throw std::invalid_argument("sort column order and null precedence must match number of keys");
+  }
+
+  auto runtime = legate::Runtime::get_runtime();
+
+  auto ret = LogicalTable::empty_like(tbl);
+
+  std::vector<cudf::size_type> keys_idx(keys.size());
+  std::vector<std::underlying_type_t<cudf::order>> column_order_lg(keys.size());
+  std::vector<std::underlying_type_t<cudf::null_order>> null_precedence_lg(keys.size());
+
+  const auto& name_to_idx = tbl.get_column_names();
+  for (size_t i = 0; i < keys.size(); i++) {
+    keys_idx[i]        = name_to_idx.at(keys[i]);
+    column_order_lg[i] = static_cast<std::underlying_type_t<cudf::order>>(column_order[i]);
+    null_precedence_lg[i] =
+      static_cast<std::underlying_type_t<cudf::null_order>>(null_precedence[i]);
+  }
+
+  legate::AutoTask task = runtime->create_task(get_library(), task::SortTask::TASK_ID);
+  argument::add_next_input(task, tbl);
+  argument::add_next_scalar_vector(task, keys_idx);
+  argument::add_next_scalar_vector(task, column_order_lg);
+  argument::add_next_scalar_vector(task, null_precedence_lg);
+  argument::add_next_scalar(task, stable);
+  argument::add_next_output(task, ret);
+
+  task.add_communicator("nccl");
+
+  runtime->submit(std::move(task));
+  return ret;
+}
+
+}  // namespace legate::dataframe
+
+namespace {
+
+void __attribute__((constructor)) register_tasks()
+{
+  legate::dataframe::task::SortTask::register_variants();
+}
+
+}  // namespace
diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt
index 150e694..8456ebf 100644
--- a/cpp/tests/CMakeLists.txt
+++ b/cpp/tests/CMakeLists.txt
@@ -1,5 +1,5 @@
 # =============================================================================
-# Copyright (c) 2023-2024, NVIDIA CORPORATION.
+# Copyright (c) 2023-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 # in compliance with the License. You may obtain a copy of the License at
@@ -39,8 +39,8 @@ target_compile_options(cpp_tests PRIVATE "$<$:${LDF_TEST_
 
 # Note that fmt::fmt should not be required, but seems to be for debug builds.
 target_link_libraries(
-  cpp_tests PRIVATE LegateDataframe cudf cudf::cudftestutil cudf::cudftestutil_impl
-                    GTest::gmock GTest::gtest fmt::fmt $
+  cpp_tests PRIVATE LegateDataframe cudf cudf::cudftestutil cudf::cudftestutil_impl GTest::gmock
+                    GTest::gtest fmt::fmt $
 )
 rapids_test_add(
   NAME cpp_tests
diff --git a/cpp/tests/test_sort.cpp b/cpp/tests/test_sort.cpp
new file mode 100644
index 0000000..db8fb9a
--- /dev/null
+++ b/cpp/tests/test_sort.cpp
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2024-2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+
+template <typename T>
+using column_wrapper = cudf::test::fixed_width_column_wrapper<T>;
+using strcol_wrapper = cudf::test::strings_column_wrapper;
+using CVector        = std::vector<std::unique_ptr<cudf::column>>;
+using Table          = cudf::table;
+
+using Order     = cudf::order;
+using NullOrder = cudf::null_order;
+
+TEST(SortTest, SimpleNoNulls)
+{
+  column_wrapper<int32_t> col_0{{3, 1, 2, 0, 2, 2}};
+  strcol_wrapper col_1({"s1", "s1", "s1", "s0", "s0", "s0"});
+  column_wrapper<int32_t> col_2{{0, 1, 2, 3, 4, 5}};
+
+  CVector cols;
+  cols.push_back(col_0.release());
+  cols.push_back(col_1.release());
+  cols.push_back(col_2.release());
+
+  Table tbl(std::move(cols));
+
+  legate::dataframe::LogicalTable lg_tbl(tbl.view(), {"a", "b", "c"});
+
+  std::vector<Order> order({Order::ASCENDING, Order::ASCENDING, Order::ASCENDING});
+  std::vector<NullOrder> null_precedence({NullOrder::AFTER, NullOrder::AFTER, NullOrder::AFTER});
+  bool stable = true;
+
+  auto expect = cudf::stable_sort(tbl, order, null_precedence);
+  auto result = legate::dataframe::sort(lg_tbl, {"a", "b", "c"}, order, null_precedence, stable);
+
+  CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*expect, result.get_cudf()->view());
+
+  // Change the first column to descending and use the unstable sort path:
+  order           = {Order::DESCENDING, Order::ASCENDING, Order::ASCENDING};
+  null_precedence = {NullOrder::AFTER, NullOrder::AFTER, NullOrder::AFTER};
+  stable          = false;
+
+  expect = cudf::sort(tbl, order, null_precedence);
+  result = legate::dataframe::sort(lg_tbl, {"a", "b", "c"}, order, null_precedence, stable);
+
+  CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*expect, result.get_cudf()->view());
+}
+
+TEST(SortTest, SimpleNullsSomeCols)
+{
+  column_wrapper<int32_t> col_0{{3, 1, 2, 0, 2, 2}, {0, 0, 1, 1, 0, 0}};
+  strcol_wrapper col_1({"s1", "s1", "s1", "s0", "s0", "s0"}, {0, 0, 0, 0, 0, 1});
+  column_wrapper<int32_t> col_2{{0, 1, 2, 3, 4, 5}};
+
+  CVector cols;
+  cols.push_back(col_0.release());
+  cols.push_back(col_1.release());
+  cols.push_back(col_2.release());
+
+  Table tbl(std::move(cols));
+
+  legate::dataframe::LogicalTable lg_tbl(tbl.view(), {"a", "b", "c"});
+
+  std::vector<Order> order({Order::ASCENDING, Order::DESCENDING});
+  std::vector<NullOrder> null_precedence({NullOrder::BEFORE, NullOrder::AFTER});
+  bool stable = true;
+
+  auto expect = cudf::stable_sort_by_key(tbl, tbl.select({0, 1}), order, null_precedence);
+  auto result = legate::dataframe::sort(lg_tbl, {"a", "b"}, order, null_precedence, stable);
+
+  CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*expect, result.get_cudf()->view());
+}
diff --git a/docs/source/api/table_funcs.rst b/docs/source/api/table_funcs.rst
index a7b91dc..66836f6 100644
--- a/docs/source/api/table_funcs.rst
+++ b/docs/source/api/table_funcs.rst
@@ -7,6 +7,9 @@ Table functions
 
 .. autofunction:: legate_dataframe.lib.join.join
 
+.. autofunction::
+   legate_dataframe.lib.sort.sort
+
 Related options/enums
 ---------------------
 
@@ -27,3 +30,12 @@ Related options/enums
 
 .. autodata:: legate_dataframe.lib.join.null_equality
    :no-value:
+
+.. autodata:: legate_dataframe.lib.sort.Order
+
+   Column sort order, either ``ASCENDING`` or ``DESCENDING`` (from ``pylibcudf``).
+
+.. autodata:: legate_dataframe.lib.sort.NullOrder
+
+   NULL sort order with respect to values, either ``BEFORE`` or ``AFTER``.
+   I.e. whether NULL is considered smaller or larger than any possible value.
diff --git a/python/legate_dataframe/lib/CMakeLists.txt b/python/legate_dataframe/lib/CMakeLists.txt
index b61aac9..dd5fc20 100644
--- a/python/legate_dataframe/lib/CMakeLists.txt
+++ b/python/legate_dataframe/lib/CMakeLists.txt
@@ -14,7 +14,7 @@
 # Set the list of Cython files to build
 set(cython_sources
     binaryop.pyx csv.pyx groupby_aggregation.pyx join.pyx parquet.pyx replace.pyx
-    timestamps.pyx unaryop.pyx
+    sort.pyx timestamps.pyx unaryop.pyx
 )
 
 rapids_cython_create_modules(
diff --git a/python/legate_dataframe/lib/sort.pyi b/python/legate_dataframe/lib/sort.pyi
new file mode 100644
index 0000000..12d7096
--- /dev/null
+++ b/python/legate_dataframe/lib/sort.pyi
@@ -0,0 +1,17 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from pylibcudf.types import NullOrder, Order
+
+from legate_dataframe.lib.core.table import LogicalTable
+
+__all__ = ["NullOrder", "Order", "sort"]
+
+def sort(
+    tbl: LogicalTable,
+    keys: list[str],
+    *,
+    column_order: list[Order] | None = None,
+    null_precedence: list[NullOrder] | None = None,
+    stable: bool = False,
+) -> LogicalTable: ...
diff --git a/python/legate_dataframe/lib/sort.pyx b/python/legate_dataframe/lib/sort.pyx
new file mode 100644
index 0000000..edff9de
--- /dev/null
+++ b/python/legate_dataframe/lib/sort.pyx
@@ -0,0 +1,83 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# distutils: language = c++
+# cython: language_level=3
+
+
+from libcpp cimport bool as cpp_bool
+from libcpp.string cimport string
+from libcpp.vector cimport vector
+
+from pylibcudf.types cimport null_order, order
+
+from legate_dataframe.lib.core.table cimport LogicalTable, cpp_LogicalTable
+
+from pylibcudf.types import NullOrder, Order
+
+from legate_dataframe.utils import _track_provenance
+
+
+cdef extern from "<legate_dataframe/sort.hpp>" nogil:
+    cpp_LogicalTable cpp_sort "legate::dataframe::sort"(
+        const cpp_LogicalTable& tbl,
+        const vector[string]& keys,
+        const vector[order]& c_order,
+        const vector[null_order]& c_null_precedence,
+        cpp_bool stable,
+    ) except +
+
+
+@_track_provenance
+def sort(
+    LogicalTable tbl,
+    list keys,
+    *,
+    list column_order = None,
+    list null_precedence = None,
+    stable = False,
+):
+    """Perform a sort of the table based on the given columns.
+
+    Parameters
+    ----------
+    tbl
+        The table to sort.
+    keys
+        The column names to sort by.
+    column_order
+        An ``Order.ASCENDING`` or ``Order.DESCENDING`` for each key denoting the
+        final order for that column.  Defaults to all ascending.
+    null_precedence
+        A ``NullOrder.BEFORE`` or ``NullOrder.AFTER`` for each key denoting whether
+        NULL values are considered smaller (before) or larger (after) than any value.
+        I.e. by default nulls are sorted "after", meaning they come last after an
+        ascending sort and first after a descending sort.
+    stable
+        Whether to perform a stable sort (default ``False``).  Stable sort currently
+        uses a less efficient merge and may not perform as well as it should.
+
+    Returns
+    -------
+    A new sorted table.
+
+    """
+    cdef vector[string] keys_vector
+    cdef vector[order] c_orders
+    cdef vector[null_order] c_null_precedence
+
+    if column_order is None:
+        c_orders = [Order.ASCENDING] * len(keys)
+    else:
+        c_orders = column_order
+
+    if null_precedence is None:
+        c_null_precedence = [NullOrder.AFTER] * len(keys)
+    else:
+        c_null_precedence = null_precedence
+
+    for k in keys:
+        keys_vector.push_back(k.encode('UTF-8'))
+
+    return LogicalTable.from_handle(
+        cpp_sort(tbl._handle, keys_vector, c_orders, c_null_precedence, stable))
diff --git a/python/tests/test_sort.py b/python/tests/test_sort.py
new file mode 100644
index 0000000..5a0302f
--- /dev/null
+++ b/python/tests/test_sort.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2025, NVIDIA CORPORATION
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import cudf
+import cupy
+import numpy as np
+import pytest
+
+from legate_dataframe import LogicalTable
+from legate_dataframe.lib.sort import NullOrder, Order, sort
+from legate_dataframe.testing import assert_frame_equal
+
+
+@pytest.mark.parametrize(
+    "values",
+    [
+        cupy.arange(0, 1000),
+        cupy.arange(0, -1000, -1),
+        cupy.ones(1000),
+        cupy.ones(1),
+        cupy.random.randint(0, 1000, size=1000),
+    ],
+)
+def test_basic(values):
+    df = cudf.DataFrame({"a": values})
+
+    lg_df = LogicalTable.from_cudf(df)
+    lg_sorted = sort(lg_df, ["a"])
+
+    df_sorted = df.sort_values(by=["a"])
+
+    assert_frame_equal(lg_sorted, df_sorted)
+
+
+@pytest.mark.parametrize(
+    "values,stable",
+    [
+        (cupy.arange(0, 1000), False),
+        (cupy.arange(0, 1000), True),
+        (cupy.arange(0, -1000, -1), False),
+        (cupy.arange(0, -1000, -1), True),
+        (cupy.ones(1000), True),
+        (cupy.ones(3), True),
+        (cupy.random.randint(0, 1000, size=1000), True),
+    ],
+)
+def test_basic_with_extra_column(values, stable):
+    # Similar to the above, but the additional column must be shuffled the
+    # same way as the key column.
+    df = cudf.DataFrame({"a": values, "b": cupy.arange(len(values))})
+
+    lg_df = LogicalTable.from_cudf(df)
+    lg_sorted = sort(lg_df, ["a"], stable=stable)
+
+    if not stable:
+        df_sorted = df.sort_values(by=["a"])
+    else:
+        df_sorted = df.sort_values(by=["a"], kind="stable")
+
+    assert_frame_equal(lg_sorted, df_sorted)
+
+
+@pytest.mark.parametrize("reversed", [True, False])
+def test_shifted_equal_window(reversed):
+    # The tricky part about sorting is picking the exact split points for the
+    # exchange.  Assume we have at least two GPUs/workers and shift a window of
+    # 50 equal values (i.e. half of each worker's chunk) through the data to see
+    # if it gets split incorrectly.
+    for i in range(150):
+        before = cupy.arange(i)
+        constant = cupy.full(50, i)
+        after = cupy.arange(50 + i, 200)
+        values = cupy.concatenate([before, constant, after])
+        if reversed:
+            values = values[::-1].copy()
+
+        # Need a second column to check the splits:
+        df = cudf.DataFrame({"a": values, "b": cupy.arange(200)})
+
+        lg_df = LogicalTable.from_cudf(df)
+        lg_sorted = sort(lg_df, ["a"], stable=True)
+        df_sorted = df.sort_values(by=["a"], kind="stable")
+
+        assert_frame_equal(lg_sorted, df_sorted)
+
+
+@pytest.mark.parametrize("stable", [True, False])
+@pytest.mark.parametrize(
+    "by,ascending,nulls_last",
+    [
+        (["a"], [True], True),  # completely standard sort
+        (["a"], [False], False),
+        (["a", "b", "c"], [True, False, True], True),
+        (["c", "a", "b"], [True, False, True], False),
+        (["c", "b", "a"], [True, False, True], True),
+    ],
+)
+def test_orders(by, ascending, nulls_last, stable):
+    # Note that cudf/pandas don't allow passing na_position as a list.
+    np.random.seed(1)
+
+    if not stable:
+        # If the sort is not stable, include the index to get stable results...
+        by.append("idx")
+        ascending.append(True)
+
+    # Generate a dataset with many repeats so that all columns should matter
+    values_a = np.arange(10).repeat(100)
+    values_b = np.arange(10.0).repeat(100)
+    values_c = ["a", "b", "hello", "d", "e", "f", "e", "🙂", "e", "g"] * 100
+
+    np.random.shuffle(values_a)
+    np.random.shuffle(values_b)
+    series_a = cudf.Series(values_a).mask(
+        np.random.choice([True, False], size=1000, p=[0.1, 0.9])
+    )
+    series_b = cudf.Series(values_b).mask(
+        np.random.choice([True, False], size=1000, p=[0.1, 0.9])
+    )
+    series_c = cudf.Series(values_c).mask(
+        np.random.choice([True, False], size=1000, p=[0.1, 0.9])
+    )
+
+    cudf_df = cudf.DataFrame(
+        {
+            "a": series_a,
+            "b": series_b,
+            "c": series_c,
+            "idx": cupy.arange(1000),
+        }
+    )
+    lg_df = LogicalTable.from_cudf(cudf_df)
+
+    kind = "stable" if stable else "quicksort"
+    na_position = "last" if nulls_last else "first"
+    expected = cudf_df.sort_values(
+        by=by, ascending=ascending, na_position=na_position, kind=kind
+    )
+
+    column_order = [Order.ASCENDING if a else Order.DESCENDING for a in ascending]
+    # If nulls are last they are considered "after" for an ascending sort, but
+    # if nulls come first they are considered "before"/smaller than all values:
+    if nulls_last:
+        null_precedence = [
+            NullOrder.AFTER if a else NullOrder.BEFORE for a in ascending
+        ]
+    else:
+        null_precedence = [
+            NullOrder.BEFORE if a else NullOrder.AFTER for a in ascending
+        ]
+
+    lg_sorted = sort(
+        lg_df,
+        keys=by,
+        column_order=column_order,
+        null_precedence=null_precedence,
+        stable=stable,
+    )
+
+    assert_frame_equal(lg_sorted, expected)
+
+
+def test_na_position_explicit():
+    cudf_df = cudf.DataFrame({"a": [0, 1, None, None], "b": [1, None, 0, None]})
+
+    lg_df = LogicalTable.from_cudf(cudf_df)
+    lg_sorted = sort(
+        lg_df, ["a", "b"], null_precedence=[NullOrder.BEFORE, NullOrder.AFTER]
+    )
+
+    expected = cudf.DataFrame({"a": [None, None, 0, 1], "b": [0, None, 1, None]})
+
+    assert_frame_equal(lg_sorted, expected)
+
+
+@pytest.mark.parametrize(
+    "keys,column_order,null_precedence",
+    [
+        ([], None, None),
+        (["bad_col"], None, None),
+        (["a"], [Order.ASCENDING] * 2, None),
+        (["a"], None, [NullOrder.BEFORE] * 2),
+        # These should fail (wrong enum passed), but cython doesn't check:
+        # (["a", "b"], [Order.ASCENDING] * 2, [Order.ASCENDING] * 2),
+        # (["a", "b"], [NullOrder.BEFORE] * 2, [NullOrder.BEFORE] * 2),
+    ],
+)
+def test_errors_incorrect_args(keys, column_order, null_precedence):
+    df = cudf.DataFrame({"a": [0, 1, 2, 3], "b": [0, 1, 2, 3]})
+    lg_df = LogicalTable.from_cudf(df)
+
+    with pytest.raises((ValueError, TypeError)):
+        sort(
+            lg_df, keys=keys, column_order=column_order, null_precedence=null_precedence
+        )