Skip to content

Commit

Permalink
Update detray to v45
Browse files Browse the repository at this point in the history
  • Loading branch information
beomki-yeo committed Oct 20, 2023
1 parent 05ad368 commit 58da02a
Show file tree
Hide file tree
Showing 46 changed files with 490 additions and 485 deletions.
5 changes: 4 additions & 1 deletion core/include/traccc/finding/finding_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ template <typename stepper_t, typename navigator_t>
class finding_algorithm
: public algorithm<track_candidate_container_types::host(
const typename navigator_t::detector_type&,
const typename stepper_t::magnetic_field_type&,
const measurement_collection_types::host&,
const bound_track_parameters_collection_types::host&)> {

Expand Down Expand Up @@ -72,6 +73,8 @@ class finding_algorithm
using interactor_type =
detray::pointwise_material_interactor<transform3_type>;

using bfield_type = typename stepper_t::magnetic_field_type;

public:
/// Configuration type
using config_type = finding_config<scalar_type>;
Expand All @@ -91,7 +94,7 @@ class finding_algorithm
/// @param measurements Input measurements
/// @param seeds Input seeds
track_candidate_container_types::host operator()(
const detector_type& det,
const detector_type& det, const bfield_type& field,
const measurement_collection_types::host& measurements,
const bound_track_parameters_collection_types::host& seeds) const;

Expand Down
4 changes: 2 additions & 2 deletions core/include/traccc/finding/finding_algorithm.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace traccc {
template <typename stepper_t, typename navigator_t>
track_candidate_container_types::host
finding_algorithm<stepper_t, navigator_t>::operator()(
const detector_type& det,
const detector_type& det, const bfield_type& field,
const measurement_collection_types::host& measurements,
const bound_track_parameters_collection_types::host& seeds) const {

Expand Down Expand Up @@ -207,7 +207,7 @@ finding_algorithm<stepper_t, navigator_t>::operator()(

// Create propagator state
typename propagator_type::state propagation(
trk_state.filtered(), det.get_bfield(), det);
trk_state.filtered(), field, det);
propagation._stepping.template set_constraint<
detray::step::constraint::e_accuracy>(
m_cfg.constrained_step_size);
Expand Down
5 changes: 4 additions & 1 deletion core/include/traccc/fitting/fitting_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ template <typename fitter_t>
class fitting_algorithm
: public algorithm<track_state_container_types::host(
const typename fitter_t::detector_type&,
const typename fitter_t::bfield_type&,
const typename track_candidate_container_types::host&)> {

public:
using transform3_type = typename fitter_t::transform3_type;
using bfield_type = typename fitter_t::bfield_type;
/// Configuration type
using config_type = typename fitter_t::config_type;

Expand All @@ -40,10 +42,11 @@ class fitting_algorithm
/// @return the container of the fitted track parameters
track_state_container_types::host operator()(
const typename fitter_t::detector_type& det,
const typename fitter_t::bfield_type& field,
const typename track_candidate_container_types::host& track_candidates)
const override {

fitter_t fitter(det, m_cfg);
fitter_t fitter(det, field, m_cfg);

track_state_container_types::host output_states;

Expand Down
14 changes: 10 additions & 4 deletions core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class kalman_fitter {
// Detector type
using detector_type = typename navigator_t::detector_type;

// Field type
using bfield_type = typename stepper_t::magnetic_field_type;

// Actor types
using aborter = detray::pathlimit_aborter;
using transporter = detray::parameter_transporter<transform3_type>;
Expand All @@ -73,8 +76,9 @@ class kalman_fitter {
///
/// @param det the detector object
TRACCC_HOST_DEVICE
kalman_fitter(const detector_type& det, const config_type& cfg)
: m_detector(det), m_cfg(cfg) {}
kalman_fitter(const detector_type& det, const bfield_type& field,
const config_type& cfg)
: m_detector(det), m_field(field), m_cfg(cfg) {}

/// Kalman fitter state
struct state {
Expand Down Expand Up @@ -163,8 +167,7 @@ class kalman_fitter {

// Create propagator state
typename propagator_type::state propagation(
seed_params, m_detector.get_bfield(), m_detector,
std::move(nav_candidates));
seed_params, m_field, m_detector, std::move(nav_candidates));

// @TODO: Should be removed once detray is fixed to set the
// volume in the constructor
Expand Down Expand Up @@ -244,6 +247,9 @@ class kalman_fitter {
private:
// Detector object
const detector_type& m_detector;
// Field object
const bfield_type m_field;

// Configuration object
config_type m_cfg;
};
Expand Down
12 changes: 6 additions & 6 deletions core/include/traccc/utils/seed_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct seed_generator {
seed_generator(const detector_t& det,
const std::array<scalar, e_bound_size>& stddevs,
const std::size_t sd = 0)
: m_detector(std::make_unique<detector_t>(det)), m_stddevs(stddevs) {
: m_detector(det), m_stddevs(stddevs) {
generator.seed(sd);
}

Expand All @@ -49,7 +49,7 @@ struct seed_generator {
const free_track_parameters& free_param) {

// Get bound parameter
const detray::surface<detector_t> sf{*m_detector, surface_link};
const detray::surface<detector_t> sf{m_detector, surface_link};

const cxt_t ctx{};
auto bound_vec = sf.free_to_bound_vector(ctx, free_param.vector());
Expand All @@ -68,10 +68,10 @@ struct seed_generator {
detray::pointwise_material_interactor<transform3_type>;

intersection_type sfi;
sfi.sf_desc = m_detector->surface(surface_link);
sfi.sf_desc = m_detector.surface(surface_link);
sf.template visit_mask<detray::intersection_update>(
detray::detail::ray<transform3_type>(free_param.vector()), sfi,
m_detector->transform_store());
m_detector.transform_store());

// Apply interactor
typename interactor_type::state interactor_state;
Expand Down Expand Up @@ -100,8 +100,8 @@ struct seed_generator {
std::random_device rd{};
std::mt19937 generator{rd()};

/// Detector objects
std::unique_ptr<detector_t> m_detector;
// Detector object
const detector_t& m_detector;
/// Standard deviations for parameter smearing
std::array<scalar, e_bound_size> m_stddevs;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace traccc::device {
///
template <typename detector_t>
TRACCC_DEVICE inline void apply_interaction(
std::size_t globalIndex, typename detector_t::detector_view_type det_data,
std::size_t globalIndex, typename detector_t::view_type det_data,
vecmem::data::jagged_vector_view<detray::intersection2D<
typename detector_t::surface_type, typename detector_t::transform3>>
nav_candidates_buffer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace traccc::device {
template <typename detector_t, typename config_t>
TRACCC_DEVICE inline void find_tracks(
std::size_t globalIndex, const config_t cfg,
typename detector_t::detector_view_type det_data,
typename detector_t::view_type det_data,
measurement_collection_types::const_view measurements_view,
vecmem::data::vector_view<const detray::geometry::barcode> barcodes_view,
vecmem::data::vector_view<const unsigned int> upper_bounds_view,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace traccc::device {

template <typename detector_t>
TRACCC_DEVICE inline void apply_interaction(
std::size_t globalIndex, typename detector_t::detector_view_type det_data,
std::size_t globalIndex, typename detector_t::view_type det_data,
vecmem::data::jagged_vector_view<detray::intersection2D<
typename detector_t::surface_type, typename detector_t::transform3>>
nav_candidates_buffer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace traccc::device {
template <typename detector_t, typename config_t>
TRACCC_DEVICE inline void find_tracks(
std::size_t globalIndex, const config_t cfg,
typename detector_t::detector_view_type det_data,
typename detector_t::view_type det_data,
measurement_collection_types::const_view measurements_view,
vecmem::data::vector_view<const detray::geometry::barcode> barcodes_view,
vecmem::data::vector_view<const unsigned int> upper_bounds_view,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

namespace traccc::device {

template <typename propagator_t, typename config_t>
template <typename propagator_t, typename bfield_t, typename config_t>
TRACCC_DEVICE inline void propagate_to_next_surface(
std::size_t globalIndex, const config_t cfg,
typename propagator_t::detector_type::detector_view_type det_data,
typename propagator_t::detector_type::view_type det_data,
bfield_t field_data,
vecmem::data::jagged_vector_view<typename propagator_t::intersection_type>
nav_candidates_buffer,
bound_track_parameters_collection_types::const_view in_params_view,
Expand Down Expand Up @@ -60,8 +61,7 @@ TRACCC_DEVICE inline void propagate_to_next_surface(

// Create propagator state
typename propagator_t::state propagation(
in_par, det.get_bfield(), det,
std::move(nav_candidates.at(globalIndex)));
in_par, field_data, det, std::move(nav_candidates.at(globalIndex)));
propagation._stepping
.template set_constraint<detray::step::constraint::e_accuracy>(
cfg.constrained_step_size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ namespace traccc::device {
/// @param[out] tips_view Tip link container for the current step
/// @param[out] n_out_params The number of output parameters
///
template <typename propagator_t, typename config_t>
template <typename propagator_t, typename bfield_t, typename config_t>
TRACCC_DEVICE inline void propagate_to_next_surface(
std::size_t globalIndex, const config_t cfg,
typename propagator_t::detector_type::detector_view_type det_data,
typename propagator_t::detector_type::view_type det_data,
bfield_t field_data,
vecmem::data::jagged_vector_view<typename propagator_t::intersection_type>
nav_candidates_buffer,
bound_track_parameters_collection_types::const_view in_params_view,
Expand Down
1 change: 1 addition & 0 deletions device/common/include/traccc/fitting/device/fit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace traccc::device {
template <typename fitter_t, typename detector_view_t>
TRACCC_HOST_DEVICE inline void fit(
std::size_t globalIndex, detector_view_t det_data,
const typename fitter_t::bfield_type field_data,
const typename fitter_t::config_type cfg,
vecmem::data::jagged_vector_view<typename fitter_t::intersection_type>
nav_candidates_buffer,
Expand Down
3 changes: 2 additions & 1 deletion device/common/include/traccc/fitting/device/impl/fit.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace traccc::device {
template <typename fitter_t, typename detector_view_t>
TRACCC_HOST_DEVICE inline void fit(
std::size_t globalIndex, detector_view_t det_data,
const typename fitter_t::bfield_type field_data,
const typename fitter_t::config_type cfg,
vecmem::data::jagged_vector_view<typename fitter_t::intersection_type>
nav_candidates_buffer,
Expand All @@ -28,7 +29,7 @@ TRACCC_HOST_DEVICE inline void fit(

track_state_container_types::device track_states(track_states_view);

fitter_t fitter(det, cfg);
fitter_t fitter(det, field_data, cfg);

if (globalIndex >= track_states.size()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ namespace traccc::device::experimental {
///
template <typename detector_t>
TRACCC_HOST_DEVICE inline void form_spacepoints(
const std::size_t globalIndex,
typename detector_t::detector_view_type det_data,
const std::size_t globalIndex, typename detector_t::view_type det_data,
measurement_collection_types::const_view measurements_view,
spacepoint_collection_types::view spacepoints_view);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ namespace traccc::device::experimental {
///
template <typename detector_t>
TRACCC_HOST_DEVICE inline void form_spacepoints(
const std::size_t globalIndex,
typename detector_t::detector_view_type det_data,
const std::size_t globalIndex, typename detector_t::view_type det_data,
measurement_collection_types::const_view measurements_view,
spacepoint_collection_types::view spacepoints_view) {

Expand Down
9 changes: 7 additions & 2 deletions device/cuda/include/traccc/cuda/finding/finding_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ namespace traccc::cuda {
template <typename stepper_t, typename navigator_t>
class finding_algorithm
: public algorithm<track_candidate_container_types::buffer(
const typename navigator_t::detector_type::detector_view_type&,
const typename navigator_t::detector_type::view_type&,
const typename stepper_t::magnetic_field_type&,
const vecmem::data::jagged_vector_view<
typename navigator_t::intersection_type>&,
const typename measurement_collection_types::view&,
Expand All @@ -49,6 +50,9 @@ class finding_algorithm
/// Detector type
using detector_type = typename navigator_t::detector_type;

/// Field type
using bfield_type = typename stepper_t::magnetic_field_type;

/// Actor types
using interactor = detray::pointwise_material_interactor<transform3_type>;

Expand Down Expand Up @@ -85,7 +89,8 @@ class finding_algorithm
/// @param navigation_buffer Buffer for navigation candidates
/// @param seeds Input seeds
track_candidate_container_types::buffer operator()(
const typename detector_type::detector_view_type& det_view,
const typename detector_type::view_type& det_view,
const bfield_type& field_view,
const vecmem::data::jagged_vector_view<
typename navigator_t::intersection_type>& navigation_buffer,
const typename measurement_collection_types::view& measurements,
Expand Down
6 changes: 4 additions & 2 deletions device/cuda/include/traccc/cuda/fitting/fitting_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ namespace traccc::cuda {
template <typename fitter_t>
class fitting_algorithm
: public algorithm<track_state_container_types::buffer(
const typename fitter_t::detector_type::detector_view_type&,
const typename fitter_t::detector_type::view_type&,
const typename fitter_t::bfield_type&,
const vecmem::data::jagged_vector_view<
typename fitter_t::intersection_type>&,
const typename track_candidate_container_types::const_view&)> {
Expand All @@ -45,7 +46,8 @@ class fitting_algorithm

/// Run the algorithm
track_state_container_types::buffer operator()(
const typename fitter_t::detector_type::detector_view_type& det_view,
const typename fitter_t::detector_type::view_type& det_view,
const typename fitter_t::bfield_type& field_view,
const vecmem::data::jagged_vector_view<
typename fitter_t::intersection_type>& navigation_buffer,
const typename track_candidate_container_types::const_view&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace traccc::cuda::experimental {
template <typename detector_t>
class spacepoint_formation
: public algorithm<spacepoint_collection_types::buffer(
const typename detector_t::detector_view_type&,
const typename detector_t::view_type&,
const measurement_collection_types::const_view&)> {

public:
Expand All @@ -50,7 +50,7 @@ class spacepoint_formation
/// @param measurements a collection of measurements
/// @return a spacepoint collection (buffer)
spacepoint_collection_types::buffer operator()(
const typename detector_t::detector_view_type& det_view,
const typename detector_t::view_type& det_view,
const measurement_collection_types::const_view& measurements_view)
const override;

Expand Down
Loading

0 comments on commit 58da02a

Please sign in to comment.