diff --git a/cpp/oneapi/dal/table/backend/csr_kernels.cpp b/cpp/oneapi/dal/table/backend/csr_kernels.cpp index 8e5aef236b2..d365168b847 100644 --- a/cpp/oneapi/dal/table/backend/csr_kernels.cpp +++ b/cpp/oneapi/dal/table/backend/csr_kernels.cpp @@ -14,6 +14,7 @@ * limitations under the License. *******************************************************************************/ +#include "oneapi/dal/backend/common.hpp" #include "oneapi/dal/table/backend/csr_kernels.hpp" #include "oneapi/dal/table/backend/convert.hpp" @@ -411,6 +412,10 @@ bool is_sorted(sycl::queue& queue, sycl::buffer count_buf(&count_descending_pairs, sycl::range<1>(1)); + const auto count_m1 = count - 1LL; + const auto wg_size = dal::backend::device_max_wg_size(queue); + const auto local_size = (wg_size < count_m1) ? wg_size : count_m1; + // count the number of pairs of the subsequent elements in the data array that are sorted // in desccending order using sycl::reduction queue @@ -419,9 +424,10 @@ bool is_sorted(sycl::queue& queue, auto count_descending_reduction = sycl::reduction(count_buf, cgh, sycl::ext::oneapi::plus()); - cgh.parallel_for(sycl::range<1>{ dal::detail::integral_cast(count - 1) }, + cgh.parallel_for(sycl::nd_range<1>{ count_m1, local_size }, count_descending_reduction, - [=](sycl::id<1> i, auto& count_descending) { + [=](sycl::nd_item<1> idx, auto& count_descending) { + const auto i = idx.get_global_id(0); if (data[i] > data[i + 1]) count_descending.combine(1); });