diff --git a/cpp/oneapi/dal/table/backend/csr_kernels.cpp b/cpp/oneapi/dal/table/backend/csr_kernels.cpp index 8e5aef236b2..7bf510318bd 100644 --- a/cpp/oneapi/dal/table/backend/csr_kernels.cpp +++ b/cpp/oneapi/dal/table/backend/csr_kernels.cpp @@ -14,6 +14,7 @@ * limitations under the License. *******************************************************************************/ +#include "oneapi/dal/backend/common.hpp" #include "oneapi/dal/table/backend/csr_kernels.hpp" #include "oneapi/dal/table/backend/convert.hpp" @@ -411,6 +412,12 @@ bool is_sorted(sycl::queue& queue, sycl::buffer count_buf(&count_descending_pairs, sycl::range<1>(1)); + const auto count_m1 = count - 1LL; + const auto wg_size = dal::backend::device_max_wg_size(queue); + const size_t count_m1_unsigned = static_cast(count_m1); + + const size_t wg_count = (count_m1 + wg_size - 1) / wg_size; + // count the number of pairs of the subsequent elements in the data array that are sorted // in desccending order using sycl::reduction queue @@ -419,10 +426,11 @@ bool is_sorted(sycl::queue& queue, auto count_descending_reduction = sycl::reduction(count_buf, cgh, sycl::ext::oneapi::plus()); - cgh.parallel_for(sycl::range<1>{ dal::detail::integral_cast(count - 1) }, + cgh.parallel_for(sycl::nd_range<1>{ wg_count * wg_size, wg_size }, count_descending_reduction, - [=](sycl::id<1> i, auto& count_descending) { - if (data[i] > data[i + 1]) + [=](sycl::nd_item<1> idx, auto& count_descending) { + const auto i = idx.get_global_id(0); + if (i < count_m1_unsigned && data[i + 1] < data[i]) count_descending.combine(1); }); })