Skip to content

Commit

Permalink
Merge pull request acts-project#467 from beomki-yeo/sort-at-clusteriz…
Browse files Browse the repository at this point in the history
…ation

Sort measurements in CCA
  • Loading branch information
beomki-yeo authored Oct 20, 2023
2 parents 116f7bf + 1d0913c commit 929479e
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 39 deletions.
25 changes: 8 additions & 17 deletions core/include/traccc/finding/finding_algorithm.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,21 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
* Measurement Operations
*****************************************************************/

// Copy the measurements
measurement_collection_types::host sorted_measurements = measurements;

// Sort the measurements w.r.t geometry barcode
std::sort(sorted_measurements.begin(), sorted_measurements.end(),
measurement_sort_comp());

// Get copy of barcode uniques
std::vector<measurement> uniques;
uniques.resize(sorted_measurements.size());
uniques.resize(measurements.size());

auto end =
std::unique_copy(sorted_measurements.begin(), sorted_measurements.end(),
uniques.begin(), measurement_equal_comp());
auto end = std::unique_copy(measurements.begin(), measurements.end(),
uniques.begin(), measurement_equal_comp());
unsigned int n_modules = end - uniques.begin();

// Get upper bounds of unique elements
std::vector<unsigned int> upper_bounds;
upper_bounds.reserve(n_modules);
for (unsigned int i = 0; i < n_modules; i++) {
auto up = std::upper_bound(sorted_measurements.begin(),
sorted_measurements.end(), uniques[i],
measurement_sort_comp());
upper_bounds.push_back(std::distance(sorted_measurements.begin(), up));
auto up = std::upper_bound(measurements.begin(), measurements.end(),
uniques[i], measurement_sort_comp());
upper_bounds.push_back(std::distance(measurements.begin(), up));
}

// Get the number of measurements of each module
Expand Down Expand Up @@ -178,7 +169,7 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
bound_track_parameters bound_param(in_param.surface_link(),
in_param.vector(),
in_param.covariance());
const auto& meas = sorted_measurements[item_id];
const auto& meas = measurements[item_id];

track_state<transform3_type> trk_state(meas);

Expand Down Expand Up @@ -281,7 +272,7 @@ finding_algorithm<stepper_t, navigator_t>::operator()(

auto& cand = *it;

cand = sorted_measurements.at(L.meas_idx);
cand = measurements.at(L.meas_idx);

// Break the loop if the iterator is at the first candidate and
// fill the seed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
// Vecmem include(s).
#include <vecmem/utils/copy.hpp>

// Thrust include(s).
#include <thrust/execution_policy.h>
#include <thrust/sort.h>

namespace traccc::cuda::experimental {

namespace {
Expand Down Expand Up @@ -141,6 +145,10 @@ clusterization_algorithm::output_type clusterization_algorithm::operator()(

m_stream.synchronize();

// Sort the measurements w.r.t geometry barcode
thrust::sort(thrust::cuda::par.on(stream), new_measurements_device.begin(),
new_measurements_device.end(), measurement_sort_comp());

return new_measurements_buffer;
}

Expand Down
32 changes: 10 additions & 22 deletions device/cuda/src/finding/finding_algorithm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -225,27 +225,16 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
* Measurement Operations
*****************************************************************/

// Copy the measurements
measurement_collection_types::buffer sorted_measurements_buffer(
m_copy->get_size(measurements), m_mr.main);
measurement_collection_types::device sorted_measurements(
sorted_measurements_buffer);
measurement_collection_types::const_device measurements_device(
measurements);
thrust::copy(thrust::device, measurements_device.begin(),
measurements_device.end(), sorted_measurements.begin());

// Sort the measurements w.r.t geometry barcode
thrust::sort(thrust::device, sorted_measurements.begin(),
sorted_measurements.end(), measurement_sort_comp());

// Get copy of barcode uniques
measurement_collection_types::buffer uniques_buffer{
sorted_measurements.size(), m_mr.main};
measurements_device.size(), m_mr.main};
measurement_collection_types::device uniques(uniques_buffer);

measurement* end = thrust::unique_copy(
thrust::device, sorted_measurements.begin(), sorted_measurements.end(),
thrust::device, measurements_device.begin(), measurements_device.end(),
uniques.begin(), measurement_equal_comp());
unsigned int n_modules = end - uniques.begin();

Expand All @@ -254,8 +243,8 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
m_mr.main};
vecmem::device_vector<unsigned int> upper_bounds(upper_bounds_buffer);

thrust::upper_bound(thrust::device, sorted_measurements.begin(),
sorted_measurements.end(), uniques.begin(),
thrust::upper_bound(thrust::device, measurements_device.begin(),
measurements_device.end(), uniques.begin(),
uniques.begin() + n_modules, upper_bounds.begin(),
measurement_sort_comp());

Expand All @@ -267,7 +256,7 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
upper_bounds.end(), sizes.begin());

// Number of total measurements
const unsigned int n_total_measurements = sorted_measurements.size();
const unsigned int n_total_measurements = measurements_device.size();

/*****************************************************************
* Kernel1: Create barcode sequence
Expand Down Expand Up @@ -365,10 +354,9 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
if (nBlocks > 0) {
kernels::find_tracks<detector_type, config_type>
<<<nBlocks, nThreads>>>(
m_cfg, det_view, sorted_measurements_buffer,
barcodes_buffer, upper_bounds_buffer, in_params_buffer,
n_threads_buffer, step,
(*global_counter_device).n_measurements_per_thread,
m_cfg, det_view, measurements, barcodes_buffer,
upper_bounds_buffer, in_params_buffer, n_threads_buffer,
step, (*global_counter_device).n_measurements_per_thread,
(*global_counter_device).n_total_threads,
updated_params_buffer, link_map[step],
(*global_counter_device).n_candidates);
Expand Down Expand Up @@ -512,8 +500,8 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
nThreads = WARP_SIZE * 2;
nBlocks = (n_tips_total + nThreads - 1) / nThreads;
kernels::build_tracks<<<nBlocks, nThreads>>>(
sorted_measurements_buffer, seeds_buffer, links_buffer,
param_to_link_buffer, tips_buffer, track_candidates_buffer);
measurements, seeds_buffer, links_buffer, param_to_link_buffer,
tips_buffer, track_candidates_buffer);

CUDA_ERROR_CHECK(cudaGetLastError());
CUDA_ERROR_CHECK(cudaDeviceSynchronize());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ clusterization_algorithm::output_type clusterization_algorithm::operator()(
sizeof(measurement) * (*num_measurements_host))
.wait_and_throw();

// @NOTE Uncomment once the onedpl is available
// oneapi::dpl::experimental::sort_async(
// oneapi::dpl::execution::dpcpp_default,
// new_measurements_device.begin(), new_measurements_device.end(),
// measurement_sort_comp());

return new_measurements_buffer;
}

Expand Down
3 changes: 3 additions & 0 deletions io/src/csv/read_measurements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ void read_measurements(measurement_reader_output& out,

result_measurements.push_back(meas);
}

std::sort(result_measurements.begin(), result_measurements.end(),
measurement_sort_comp());
}

measurement_container_types::host read_measurements_container(
Expand Down

0 comments on commit 929479e

Please sign in to comment.