diff --git a/tests/common/tests/cca_test.hpp b/tests/common/tests/cca_test.hpp index dce8c429f6..0499cd05c7 100644 --- a/tests/common/tests/cca_test.hpp +++ b/tests/common/tests/cca_test.hpp @@ -47,6 +47,17 @@ inline traccc::clustering_config default_ccl_test_config() { return rv; } +inline traccc::clustering_config tiny_ccl_test_config() { + traccc::clustering_config rv; + + rv.threads_per_partition = 128; + rv.max_cells_per_thread = 1; + rv.target_cells_per_thread = 1; + rv.backup_size_multiplier = 16384; + + return rv; +} + class ConnectedComponentAnalysisTests : public traccc::tests::data_test, public testing::WithParamInterface< @@ -95,6 +106,23 @@ class ConnectedComponentAnalysisTests return out; } + inline static std::vector get_test_files_short(void) { + const std::vector> cases = { + {"trackml_like", 10}, + }; + std::vector out; + + for (const std::pair &c : cases) { + for (std::size_t i = 0; i < c.second; ++i) { + std::ostringstream ss; + ss << c.first << "_" << std::setfill('0') << std::setw(10) << i; + out.push_back(ss.str()); + } + } + + return out; + } + inline void test_connected_component_analysis(ParamType p) { cca_function_t f = std::get<0>(p); std::string file_prefix = std::get<1>(p); diff --git a/tests/cuda/test_cca.cpp b/tests/cuda/test_cca.cpp index 47cef13e81..2c515ef007 100644 --- a/tests/cuda/test_cca.cpp +++ b/tests/cuda/test_cca.cpp @@ -13,49 +13,54 @@ #include #include "tests/cca_test.hpp" +#include "traccc/clusterization/clustering_config.hpp" #include "traccc/cuda/clusterization/clusterization_algorithm.hpp" #include "traccc/cuda/utils/stream.hpp" namespace { -cca_function_t f = [](const traccc::cell_collection_types::host& cells, - const traccc::cell_module_collection_types::host& - modules) { - std::map> result; +cca_function_t get_f_with(traccc::clustering_config cfg) { + return [cfg](const traccc::cell_collection_types::host& cells, + const traccc::cell_module_collection_types::host& modules) { + std::map> + result; - vecmem::host_memory_resource host_mr; - traccc::cuda::stream stream; - vecmem::cuda::device_memory_resource device_mr; - vecmem::cuda::async_copy copy{stream.cudaStream()}; + vecmem::host_memory_resource host_mr; + traccc::cuda::stream stream; + vecmem::cuda::device_memory_resource device_mr; + vecmem::cuda::async_copy copy{stream.cudaStream()}; - traccc::cuda::clusterization_algorithm cc({device_mr}, copy, stream, - default_ccl_test_config()); + traccc::cuda::clusterization_algorithm cc({device_mr}, copy, stream, + cfg); - traccc::cell_collection_types::buffer cells_buffer{ - static_cast( - cells.size()), - device_mr}; - copy.setup(cells_buffer); - copy(vecmem::get_data(cells), cells_buffer)->ignore(); + traccc::cell_collection_types::buffer cells_buffer{ + static_cast( + cells.size()), + device_mr}; + copy.setup(cells_buffer); + copy(vecmem::get_data(cells), cells_buffer)->ignore(); - traccc::cell_module_collection_types::buffer modules_buffer{ - static_cast( - modules.size()), - device_mr}; - copy.setup(modules_buffer); - copy(vecmem::get_data(modules), modules_buffer)->ignore(); + traccc::cell_module_collection_types::buffer modules_buffer{ + static_cast< + traccc::cell_module_collection_types::buffer::size_type>( + modules.size()), + device_mr}; + copy.setup(modules_buffer); + copy(vecmem::get_data(modules), modules_buffer)->ignore(); - auto measurements_buffer = cc(cells_buffer, modules_buffer); - traccc::measurement_collection_types::host measurements{&host_mr}; - copy(measurements_buffer, measurements)->wait(); + auto measurements_buffer = cc(cells_buffer, modules_buffer); + traccc::measurement_collection_types::host measurements{&host_mr}; + copy(measurements_buffer, measurements)->wait(); - for (std::size_t i = 0; i < measurements.size(); i++) { - result[modules.at(measurements.at(i).module_link).surface_link.value()] - .push_back(measurements.at(i)); - } + for (std::size_t i = 0; i < measurements.size(); i++) { + result[modules.at(measurements.at(i).module_link) + .surface_link.value()] + .push_back(measurements.at(i)); + } - return result; -}; + return result; + }; +} } // namespace TEST_P(ConnectedComponentAnalysisTests, Run) { @@ -65,6 +70,14 @@ TEST_P(ConnectedComponentAnalysisTests, Run) { INSTANTIATE_TEST_SUITE_P( CUDAFastSvAlgorithm, ConnectedComponentAnalysisTests, ::testing::Combine( - ::testing::Values(f), + ::testing::Values(get_f_with(default_ccl_test_config())), ::testing::ValuesIn(ConnectedComponentAnalysisTests::get_test_files())), ConnectedComponentAnalysisTests::get_test_name); + +INSTANTIATE_TEST_SUITE_P( + CUDAFastSvAlgorithmWithScratch, ConnectedComponentAnalysisTests, + ::testing::Combine( + ::testing::Values(get_f_with(tiny_ccl_test_config())), + ::testing::ValuesIn( + ConnectedComponentAnalysisTests::get_test_files_short())), + ConnectedComponentAnalysisTests::get_test_name); diff --git a/tests/sycl/test_cca.sycl b/tests/sycl/test_cca.sycl index 22d0635707..a068fadb59 100644 --- a/tests/sycl/test_cca.sycl +++ b/tests/sycl/test_cca.sycl @@ -19,44 +19,48 @@ namespace { -cca_function_t f = [](const traccc::cell_collection_types::host& cells, - const traccc::cell_module_collection_types::host& - modules) { - std::map> result; +cca_function_t get_f_with(traccc::clustering_config cfg) { + return [cfg](const traccc::cell_collection_types::host& cells, + const traccc::cell_module_collection_types::host& modules) { + std::map> + result; - vecmem::host_memory_resource host_mr; - cl::sycl::queue queue; - vecmem::sycl::device_memory_resource device_mr; - vecmem::sycl::async_copy copy{&queue}; + vecmem::host_memory_resource host_mr; + cl::sycl::queue queue; + vecmem::sycl::device_memory_resource device_mr; + vecmem::sycl::async_copy copy{&queue}; - traccc::sycl::clusterization_algorithm cc({device_mr}, copy, &queue, - default_ccl_test_config()); + traccc::sycl::clusterization_algorithm cc({device_mr}, copy, &queue, + cfg); - traccc::cell_collection_types::buffer cells_buffer{ - static_cast( - cells.size()), - device_mr}; - copy.setup(cells_buffer); - copy(vecmem::get_data(cells), cells_buffer)->ignore(); + traccc::cell_collection_types::buffer cells_buffer{ + static_cast( + cells.size()), + device_mr}; + copy.setup(cells_buffer); + copy(vecmem::get_data(cells), cells_buffer)->ignore(); - traccc::cell_module_collection_types::buffer modules_buffer{ - static_cast( - modules.size()), - device_mr}; - copy.setup(modules_buffer); - copy(vecmem::get_data(modules), modules_buffer)->ignore(); + traccc::cell_module_collection_types::buffer modules_buffer{ + static_cast< + traccc::cell_module_collection_types::buffer::size_type>( + modules.size()), + device_mr}; + copy.setup(modules_buffer); + copy(vecmem::get_data(modules), modules_buffer)->ignore(); - auto measurements_buffer = cc(cells_buffer, modules_buffer); - traccc::measurement_collection_types::host measurements{&host_mr}; - copy(measurements_buffer, measurements)->wait(); + auto measurements_buffer = cc(cells_buffer, modules_buffer); + traccc::measurement_collection_types::host measurements{&host_mr}; + copy(measurements_buffer, measurements)->wait(); - for (std::size_t i = 0; i < measurements.size(); i++) { - result[modules.at(measurements.at(i).module_link).surface_link.value()] - .push_back(measurements.at(i)); - } + for (std::size_t i = 0; i < measurements.size(); i++) { + result[modules.at(measurements.at(i).module_link) + .surface_link.value()] + .push_back(measurements.at(i)); + } - return result; -}; + return result; + }; +} } // namespace TEST_P(ConnectedComponentAnalysisTests, Run) { @@ -66,6 +70,14 @@ TEST_P(ConnectedComponentAnalysisTests, Run) { INSTANTIATE_TEST_SUITE_P( SYCLFastSvAlgorithm, ConnectedComponentAnalysisTests, ::testing::Combine( - ::testing::Values(f), + ::testing::Values(get_f_with(default_ccl_test_config())), ::testing::ValuesIn(ConnectedComponentAnalysisTests::get_test_files())), ConnectedComponentAnalysisTests::get_test_name); + +INSTANTIATE_TEST_SUITE_P( + SYCLFastSvAlgorithmWithScratch, ConnectedComponentAnalysisTests, + ::testing::Combine( + ::testing::Values(get_f_with(tiny_ccl_test_config())), + ::testing::ValuesIn( + ConnectedComponentAnalysisTests::get_test_files_short())), + ConnectedComponentAnalysisTests::get_test_name);