Skip to content

Commit

Permalink
Only use detector input
Browse files Browse the repository at this point in the history
  • Loading branch information
beomki-yeo committed Oct 17, 2024
1 parent 96bef6b commit 5ee51aa
Show file tree
Hide file tree
Showing 23 changed files with 54 additions and 77 deletions.
4 changes: 2 additions & 2 deletions examples/io/create_binaries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ int create_binaries(const traccc::opts::detector& detector_opts,
// Read the hits from the relevant event file
traccc::spacepoint_collection_types::host spacepoints{&host_mr};
traccc::io::read_spacepoints(spacepoints, event, input_opts.directory,
false, nullptr, input_opts.format);
nullptr, input_opts.format);

// Write binary file
traccc::io::write(event, output_opts.directory,
Expand All @@ -62,7 +62,7 @@ int create_binaries(const traccc::opts::detector& detector_opts,
// Read the measurements from the relevant event file
traccc::measurement_collection_types::host measurements{&host_mr};
traccc::io::read_measurements(measurements, event, input_opts.directory,
false, nullptr, input_opts.format);
nullptr, input_opts.format);

// Write binary file
traccc::io::write(event, output_opts.directory,
Expand Down
16 changes: 8 additions & 8 deletions examples/run/alpaka/seeding_example_alpaka.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,16 +172,16 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
traccc::performance::timer t("Hit reading (cpu)",
elapsedTimes);
// Read the hits from the relevant event file
traccc::io::read_spacepoints(spacepoints_per_event, event,
input_opts.directory,
input_opts.use_acts_geom_source,
&host_det, input_opts.format);
traccc::io::read_spacepoints(
spacepoints_per_event, event, input_opts.directory,
(input_opts.use_acts_geom_source ? &detector : nullptr),
input_opts.format);

// Read measurements
traccc::io::read_measurements(measurements_per_event, event,
input_opts.directory,
input_opts.use_acts_geom_source,
&host_det, input_opts.format);
traccc::io::read_measurements(
measurements_per_event, event, input_opts.directory,
(input_opts.use_acts_geom_source ? &detector : nullptr),
input_opts.format);
} // stop measuring hit reading timer

/*----------------------------
Expand Down
3 changes: 2 additions & 1 deletion examples/run/cpu/seeding_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
&host_mr};
traccc::io::read_spacepoints(
spacepoints_per_event, event, input_opts.directory,
input_opts.use_acts_geom_source, &detector, input_opts.format);
(input_opts.use_acts_geom_source ? &detector : nullptr),
input_opts.format);
n_spacepoints += spacepoints_per_event.size();

/*----------------
Expand Down
3 changes: 2 additions & 1 deletion examples/run/cpu/truth_finding_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
&host_mr};
traccc::io::read_measurements(
measurements_per_event, event, input_opts.directory,
input_opts.use_acts_geom_source, &detector, input_opts.format);
(input_opts.use_acts_geom_source ? &detector : nullptr),
input_opts.format);

// Run finding
auto track_candidates =
Expand Down
16 changes: 8 additions & 8 deletions examples/run/cuda/seeding_example_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,16 +224,16 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
traccc::performance::timer t("Hit reading (cpu)",
elapsedTimes);
// Read the hits from the relevant event file
traccc::io::read_spacepoints(spacepoints_per_event, event,
input_opts.directory,
input_opts.use_acts_geom_source,
&host_det, input_opts.format);
traccc::io::read_spacepoints(
spacepoints_per_event, event, input_opts.directory,
(input_opts.use_acts_geom_source ? &detector : nullptr),
input_opts.format);

// Read measurements
traccc::io::read_measurements(measurements_per_event, event,
input_opts.directory,
input_opts.use_acts_geom_source,
&host_det, input_opts.format);
traccc::io::read_measurements(
measurements_per_event, event,
(input_opts.use_acts_geom_source ? &detector : nullptr),
&host_det, input_opts.format);
} // stop measuring hit reading timer

/*----------------------------
Expand Down
3 changes: 2 additions & 1 deletion examples/run/cuda/truth_finding_example_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
mr.host};
traccc::io::read_measurements(
measurements_per_event, event, input_opts.directory,
input_opts.use_acts_geom_source, &detector, input_opts.format);
(input_opts.use_acts_geom_source ? &detector : nullptr),
input_opts.format);

traccc::measurement_collection_types::buffer measurements_cuda_buffer(
measurements_per_event.size(), mr.main);
Expand Down
8 changes: 4 additions & 4 deletions examples/run/kokkos/seeding_example_kokkos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ int seq_run(const traccc::opts::track_seeding& seeding_opts,
traccc::performance::timer t("Hit reading (cpu)",
elapsedTimes);
// Read the hits from the relevant event file
traccc::io::read_spacepoints(spacepoints_per_event, event,
input_opts.directory,
input_opts.use_acts_geom_source,
&host_det, input_opts.format);
traccc::io::read_spacepoints(
spacepoints_per_event, event, input_opts.directory,
(input_opts.use_acts_geom_source ? &detector : nullptr),
input_opts.format);
} // stop measuring hit reading timer

{ // Spacepoin binning for kokkos
Expand Down
8 changes: 4 additions & 4 deletions examples/run/sycl/seeding_example_sycl.sycl
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ int seq_run(const traccc::opts::detector& detector_opts,
traccc::performance::timer t("Hit reading (cpu)",
elapsedTimes);
// Read the hits from the relevant event file
traccc::io::read_spacepoints(spacepoints_per_event, event,
input_opts.directory,
input_opts.use_acts_geom_source,
&host_det, input_opts.format);
traccc::io::read_spacepoints(
spacepoints_per_event, event, input_opts.directory,
(input_opts.use_acts_geom_source ? &detector : nullptr),
input_opts.format);

} // stop measuring hit reading timer

Expand Down
4 changes: 0 additions & 4 deletions io/include/traccc/io/read_measurements.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,11 @@ namespace traccc::io {
/// @param[out] measurements The measurement collection to fill
/// @param[in] event The event ID to read in the measurements for
/// @param[in] directory The directory holding the measurement data files
/// @param[in] use_acts_geom_source Use acts geometry source
/// @param[in] detector detray detector
/// @param[in] format The format of the measurement data files (to read)
///
void read_measurements(measurement_collection_types::host& measurements,
std::size_t event, std::string_view directory,
bool use_acts_geom_source = true,
const traccc::default_detector::host* detector = nullptr,
data_format format = data_format::csv);

Expand All @@ -44,13 +42,11 @@ void read_measurements(measurement_collection_types::host& measurements,
///
/// @param[out] measurements The measurement collection to fill
/// @param[in] filename The file to read the measurement data from
/// @param[in] use_acts_geom_source Use acts geometry source
/// @param[in] detector detray detector
/// @param[in] format The format of the measurement data files (to read)
///
void read_measurements(measurement_collection_types::host& measurements,
std::string_view filename,
bool use_acts_geom_source = true,
const traccc::default_detector::host* detector = nullptr,
data_format format = data_format::csv);

Expand Down
4 changes: 0 additions & 4 deletions io/include/traccc/io/read_particles.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,11 @@ void read_particles(particle_collection_types::host &particles,
/// @param[in] event The event ID to read in the particles for
/// @param[in] directory The directory holding the particle data files
/// @param[in] format The format of the particle data files (to read)
/// @param[in] use_acts_geom_source Use acts geometry source
/// @param[in] detector detray detector
/// @param[in] filename_postfix Postfix for the particle file name(s)
///
void read_particles(particle_container_types::host &particles,
std::size_t event, std::string_view directory,
bool use_acts_geom_source = true,
const traccc::default_detector::host *detector = nullptr,
data_format format = data_format::csv,
std::string_view filename_postfix = "-particles_initial");
Expand All @@ -77,15 +75,13 @@ void read_particles(particle_container_types::host &particles,
/// @param[in] hits_file The file to read the simulated hits from
/// @param[in] measurements_file The file to read the "Acts measurements" from
/// @param[in] hit_map_file The file to read the hit->measurement mapping from
/// @param[in] use_acts_geom_source Use acts geometry source
/// @param[in] detector detray detector
/// @param[in] format The format of the particle data files (to read)
///
void read_particles(particle_container_types::host &particles,
std::string_view particles_file, std::string_view hits_file,
std::string_view measurements_file,
std::string_view hit_map_file,
bool use_acts_geom_source = true,
const traccc::default_detector::host *detector = nullptr,
data_format format = data_format::csv);

Expand Down
4 changes: 0 additions & 4 deletions io/include/traccc/io/read_spacepoints.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,11 @@ namespace traccc::io {
/// @param[out] spacepoints The spacepoint collection to fill
/// @param[in] event The event ID to read in the spacepoints for
/// @param[in] directory The directory holding the spacepoint data files
/// @param[in] use_acts_geom_source Use acts geometry source
/// @param[in] detector detray detector
/// @param[in] format The format of the data files (to read)
///
void read_spacepoints(spacepoint_collection_types::host& spacepoints,
std::size_t event, std::string_view directory,
bool use_acts_geom_source = true,
const traccc::default_detector::host* detector = nullptr,
data_format format = data_format::csv);

Expand All @@ -47,15 +45,13 @@ void read_spacepoints(spacepoint_collection_types::host& spacepoints,
/// @param[in] meas_filename The file to read the measurement data from
/// @param[in] meas_hit_map_filename The file to read the mapping from
/// measurements to hits from
/// @param[in] use_acts_geom_source Use acts geometry source
/// @param[in] detector detray detector
/// @param[in] format The format of the data files (to read)
///
void read_spacepoints(spacepoint_collection_types::host& spacepoints,
std::string_view hit_filename,
std::string_view meas_filename,
std::string_view meas_hit_map_filename,
bool use_acts_geom_source = true,
const traccc::default_detector::host* detector = nullptr,
data_format format = data_format::csv);

Expand Down
8 changes: 4 additions & 4 deletions io/src/csv/read_measurements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
namespace traccc::io::csv {

void read_measurements(measurement_collection_types::host& measurements,
std::string_view filename, bool use_acts_geom_source,
std::string_view filename,
const traccc::default_detector::host* detector,
const bool do_sort) {

Expand All @@ -26,7 +26,7 @@ void read_measurements(measurement_collection_types::host& measurements,
// For Acts data, build a map of acts->detray geometry IDs
std::map<geometry_id, geometry_id> acts_to_detray_id;

if (use_acts_geom_source && detector) {
if (detector) {
for (const auto& surface_desc : detector->surfaces()) {
acts_to_detray_id[surface_desc.source] =
surface_desc.barcode().value();
Expand All @@ -38,8 +38,8 @@ void read_measurements(measurement_collection_types::host& measurements,
while (reader.read(iomeas)) {

traccc::geometry_id geom_id = iomeas.geometry_id;
if (use_acts_geom_source && detector) {
geom_id = acts_to_detray_id[iomeas.geometry_id];
if (detector) {
geom_id = acts_to_detray_id.at(iomeas.geometry_id);
}

// Construct the measurement object.
Expand Down
2 changes: 0 additions & 2 deletions io/src/csv/read_measurements.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ namespace traccc::io::csv {
///
/// @param[out] measurements The collection to fill with the measurement data
/// @param[in] filename The file to read the measurement data from
/// @param[in] use_acts_geom_source Use acts geometry source
/// @param[in] detector detray detector
/// @param[in] do_sort Whether to sort the measurements or not
///
void read_measurements(measurement_collection_types::host& measurements,
std::string_view filename,
bool use_acts_geom_source = true,
const traccc::default_detector::host* detector = nullptr,
const bool do_sort = true);

Expand Down
6 changes: 3 additions & 3 deletions io/src/csv/read_particles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void read_particles(particle_collection_types::host& particles,
void read_particles(particle_container_types::host& particles,
std::string_view particles_file, std::string_view hits_file,
std::string_view measurements_file,
std::string_view hit_map_file, bool use_acts_geom_source,
std::string_view hit_map_file,
const traccc::default_detector::host* detector) {

// Memory resource used by the temporary collections.
Expand All @@ -61,8 +61,8 @@ void read_particles(particle_container_types::host& particles,
// Read in all measurements, into a temporary collection.
static constexpr bool sort_measurements = false;
measurement_collection_types::host temp_measurements{&mr};
read_measurements(temp_measurements, measurements_file,
use_acts_geom_source, detector, sort_measurements);
read_measurements(temp_measurements, measurements_file, detector,
sort_measurements);

// Make a hit to measurement map.
std::unordered_map<std::size_t, std::size_t> hit_to_measurement;
Expand Down
3 changes: 1 addition & 2 deletions io/src/csv/read_particles.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,12 @@ void read_particles(particle_collection_types::host& particles,
/// @param[in] hits_file The file to read the simulated hits from
/// @param[in] measurements_file The file to read the "Acts measurements" from
/// @param[in] hit_map_file The file to read the hit->measurement mapping from
/// @param[in] use_acts_geom_source Use acts geometry source
/// @param[in] detector detray detector
///
void read_particles(particle_container_types::host& particles,
std::string_view particles_file, std::string_view hits_file,
std::string_view measurements_file,
std::string_view hit_map_file, bool use_acts_geom_source,
std::string_view hit_map_file,
const traccc::default_detector::host* detector);

} // namespace traccc::io::csv
4 changes: 1 addition & 3 deletions io/src/csv/read_spacepoints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,12 @@ void read_spacepoints(spacepoint_collection_types::host& spacepoints,
std::string_view hit_filename,
std::string_view meas_filename,
std::string_view meas_hit_map_filename,
bool use_acts_geom_source,
const traccc::default_detector::host* detector) {

// Read all measurements.
measurement_collection_types::host measurements;
static constexpr bool sort_measurements = false;
read_measurements(measurements, meas_filename, use_acts_geom_source,
detector, sort_measurements);
read_measurements(measurements, meas_filename, detector, sort_measurements);

// Measurement hit id reader
auto mhid_reader =
Expand Down
2 changes: 0 additions & 2 deletions io/src/csv/read_spacepoints.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@ namespace traccc::io::csv {
/// @param[in] meas_filename The file to read the measurement data from
/// @param[in] meas_hit_map_filename The file to read the mapping from
/// measurements to hits from
/// @param[in] use_acts_geom_source Use acts geometry source
/// @param[in] detector detray detector
///
void read_spacepoints(spacepoint_collection_types::host& spacepoints,
std::string_view hit_filename,
std::string_view meas_filename,
std::string_view meas_hit_map_filename,
bool use_acts_geom_source = false,
const traccc::default_detector::host* detector = nullptr);

} // namespace traccc::io::csv
8 changes: 3 additions & 5 deletions io/src/read_measurements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ namespace traccc::io {

void read_measurements(measurement_collection_types::host& measurements,
std::size_t event, std::string_view directory,
bool use_acts_geom_source,
const traccc::default_detector::host* detector,
data_format format) {

Expand All @@ -31,7 +30,7 @@ void read_measurements(measurement_collection_types::host& measurements,
std::filesystem::path(get_event_filename(
event, "-measurements.csv")))
.native()),
use_acts_geom_source, detector, format);
detector, format);
break;
}
case data_format::binary: {
Expand All @@ -50,15 +49,14 @@ void read_measurements(measurement_collection_types::host& measurements,
}

void read_measurements(measurement_collection_types::host& measurements,
std::string_view filename, bool use_acts_geom_source,
std::string_view filename,
const traccc::default_detector::host* detector,
data_format format) {

static constexpr bool sort_measurements = true;
switch (format) {
case data_format::csv:
return csv::read_measurements(measurements, filename,
use_acts_geom_source, detector,
return csv::read_measurements(measurements, filename, detector,
sort_measurements);
default:
throw std::invalid_argument("Unsupported data format");
Expand Down
Loading

0 comments on commit 5ee51aa

Please sign in to comment.