Skip to content

Commit

Permalink
Sort measurements before Ckf
Browse files Browse the repository at this point in the history
  • Loading branch information
beomki-yeo committed Oct 18, 2023
1 parent 05ad368 commit cef106d
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
// Vecmem include(s).
#include <vecmem/utils/copy.hpp>

// Thrust include(s).
#include <thrust/execution_policy.h>
#include <thrust/sort.h>

namespace traccc::cuda::experimental {

namespace {
Expand Down Expand Up @@ -141,6 +145,10 @@ clusterization_algorithm::output_type clusterization_algorithm::operator()(

m_stream.synchronize();

// Sort the measurements w.r.t geometry barcode
thrust::sort(thrust::cuda::par.on(stream), new_measurements_device.begin(),
new_measurements_device.end(), measurement_sort_comp());

return new_measurements_buffer;
}

Expand Down
32 changes: 10 additions & 22 deletions device/cuda/src/finding/finding_algorithm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -222,27 +222,16 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
* Measurement Operations
*****************************************************************/

// Copy the measurements
measurement_collection_types::buffer sorted_measurements_buffer(
m_copy->get_size(measurements), m_mr.main);
measurement_collection_types::device sorted_measurements(
sorted_measurements_buffer);
measurement_collection_types::const_device measurements_device(
measurements);
thrust::copy(thrust::device, measurements_device.begin(),
measurements_device.end(), sorted_measurements.begin());

// Sort the measurements w.r.t geometry barcode
thrust::sort(thrust::device, sorted_measurements.begin(),
sorted_measurements.end(), measurement_sort_comp());

// Get copy of barcode uniques
measurement_collection_types::buffer uniques_buffer{
sorted_measurements.size(), m_mr.main};
measurements_device.size(), m_mr.main};
measurement_collection_types::device uniques(uniques_buffer);

measurement* end = thrust::unique_copy(
thrust::device, sorted_measurements.begin(), sorted_measurements.end(),
thrust::device, measurements_device.begin(), measurements_device.end(),
uniques.begin(), measurement_equal_comp());
unsigned int n_modules = end - uniques.begin();

Expand All @@ -251,8 +240,8 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
m_mr.main};
vecmem::device_vector<unsigned int> upper_bounds(upper_bounds_buffer);

thrust::upper_bound(thrust::device, sorted_measurements.begin(),
sorted_measurements.end(), uniques.begin(),
thrust::upper_bound(thrust::device, measurements_device.begin(),
measurements_device.end(), uniques.begin(),
uniques.begin() + n_modules, upper_bounds.begin(),
measurement_sort_comp());

Expand All @@ -264,7 +253,7 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
upper_bounds.end(), sizes.begin());

// Number of total measurements
const unsigned int n_total_measurements = sorted_measurements.size();
const unsigned int n_total_measurements = measurements_device.size();

/*****************************************************************
* Kernel1: Create barcode sequence
Expand Down Expand Up @@ -362,10 +351,9 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
if (nBlocks > 0) {
kernels::find_tracks<detector_type, config_type>
<<<nBlocks, nThreads>>>(
m_cfg, det_view, sorted_measurements_buffer,
barcodes_buffer, upper_bounds_buffer, in_params_buffer,
n_threads_buffer, step,
(*global_counter_device).n_measurements_per_thread,
m_cfg, det_view, measurements, barcodes_buffer,
upper_bounds_buffer, in_params_buffer, n_threads_buffer,
step, (*global_counter_device).n_measurements_per_thread,
(*global_counter_device).n_total_threads,
updated_params_buffer, link_map[step],
(*global_counter_device).n_candidates);
Expand Down Expand Up @@ -507,8 +495,8 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
nThreads = WARP_SIZE * 2;
nBlocks = (n_tips_total + nThreads - 1) / nThreads;
kernels::build_tracks<<<nBlocks, nThreads>>>(
sorted_measurements_buffer, seeds_buffer, links_buffer,
param_to_link_buffer, tips_buffer, track_candidates_buffer);
measurements, seeds_buffer, links_buffer, param_to_link_buffer,
tips_buffer, track_candidates_buffer);

CUDA_ERROR_CHECK(cudaGetLastError());
CUDA_ERROR_CHECK(cudaDeviceSynchronize());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ clusterization_algorithm::output_type clusterization_algorithm::operator()(
sizeof(measurement) * (*num_measurements_host))
.wait_and_throw();

// @NOTE Uncomment once the onedpl is available
// oneapi::dpl::experimental::sort_async(
// oneapi::dpl::execution::dpcpp_default,
// new_measurements_device.begin(), new_measurements_device.end(),
// measurement_sort_comp());

return new_measurements_buffer;
}

Expand Down
3 changes: 3 additions & 0 deletions tests/cuda/test_ckf_sparse_tracks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ TEST_P(CkfSparseTrackTests, Run) {
traccc::measurement_collection_types::host& measurements_per_event =
readOut.measurements;

std::sort(measurements_per_event.begin(), measurements_per_event.end(),
measurement_sort_comp());

traccc::measurement_collection_types::buffer measurements_buffer(
measurements_per_event.size(), mr.main);
copy(vecmem::get_data(measurements_per_event), measurements_buffer);
Expand Down

0 comments on commit cef106d

Please sign in to comment.