Skip to content

Commit

Permalink
Backup
Browse files Browse the repository at this point in the history
  • Loading branch information
beomki-yeo committed Oct 19, 2023
1 parent 9ef6690 commit 4eaa933
Show file tree
Hide file tree
Showing 29 changed files with 401 additions and 379 deletions.
5 changes: 4 additions & 1 deletion core/include/traccc/finding/finding_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ template <typename stepper_t, typename navigator_t>
class finding_algorithm
: public algorithm<track_candidate_container_types::host(
const typename navigator_t::detector_type&,
const typename stepper_t::magnetic_field_type&,
const measurement_collection_types::host&,
const bound_track_parameters_collection_types::host&)> {

Expand Down Expand Up @@ -72,6 +73,8 @@ class finding_algorithm
using interactor_type =
detray::pointwise_material_interactor<transform3_type>;

using bfield_type = typename stepper_t::magnetic_field_type;

public:
/// Configuration type
using config_type = finding_config<scalar_type>;
Expand All @@ -91,7 +94,7 @@ class finding_algorithm
/// @param measurements Input measurements
/// @param seeds Input seeds
track_candidate_container_types::host operator()(
const detector_type& det,
const detector_type& det, const bfield_type& field,
const measurement_collection_types::host& measurements,
const bound_track_parameters_collection_types::host& seeds) const;

Expand Down
4 changes: 2 additions & 2 deletions core/include/traccc/finding/finding_algorithm.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace traccc {
template <typename stepper_t, typename navigator_t>
track_candidate_container_types::host
finding_algorithm<stepper_t, navigator_t>::operator()(
const detector_type& det,
const detector_type& det, const bfield_type& field,
const measurement_collection_types::host& measurements,
const bound_track_parameters_collection_types::host& seeds) const {

Expand Down Expand Up @@ -207,7 +207,7 @@ finding_algorithm<stepper_t, navigator_t>::operator()(

// Create propagator state
typename propagator_type::state propagation(
trk_state.filtered(), det.get_bfield(), det);
trk_state.filtered(), field, det);
propagation._stepping.template set_constraint<
detray::step::constraint::e_accuracy>(
m_cfg.constrained_step_size);
Expand Down
5 changes: 4 additions & 1 deletion core/include/traccc/fitting/fitting_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ template <typename fitter_t>
class fitting_algorithm
: public algorithm<track_state_container_types::host(
const typename fitter_t::detector_type&,
const typename fitter_t::bfield_type&,
const typename track_candidate_container_types::host&)> {

public:
using transform3_type = typename fitter_t::transform3_type;
using bfield_type = typename fitter_t::bfield_type;
/// Configuration type
using config_type = typename fitter_t::config_type;

Expand All @@ -40,10 +42,11 @@ class fitting_algorithm
/// @return the container of the fitted track parameters
track_state_container_types::host operator()(
const typename fitter_t::detector_type& det,
const typename fitter_t::bfield_type& field,
const typename track_candidate_container_types::host& track_candidates)
const override {

fitter_t fitter(det, m_cfg);
fitter_t fitter(det, field, m_cfg);

track_state_container_types::host output_states;

Expand Down
14 changes: 10 additions & 4 deletions core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class kalman_fitter {
// Detector type
using detector_type = typename navigator_t::detector_type;

// Field type
using bfield_type = typename stepper_t::magnetic_field_type;

// Actor types
using aborter = detray::pathlimit_aborter;
using transporter = detray::parameter_transporter<transform3_type>;
Expand All @@ -73,8 +76,9 @@ class kalman_fitter {
///
/// @param det the detector object
TRACCC_HOST_DEVICE
kalman_fitter(const detector_type& det, const config_type& cfg)
: m_detector(det), m_cfg(cfg) {}
kalman_fitter(const detector_type& det, const bfield_type& field,
const config_type& cfg)
: m_detector(det), m_field(field), m_cfg(cfg) {}

/// Kalman fitter state
struct state {
Expand Down Expand Up @@ -163,8 +167,7 @@ class kalman_fitter {

// Create propagator state
typename propagator_type::state propagation(
seed_params, m_detector.get_bfield(), m_detector,
std::move(nav_candidates));
seed_params, m_field, m_detector, std::move(nav_candidates));

// @TODO: Should be removed once detray is fixed to set the
// volume in the constructor
Expand Down Expand Up @@ -244,6 +247,9 @@ class kalman_fitter {
private:
// Detector object
const detector_type& m_detector;
// Field object
const bfield_type& m_field;

// Configuration object
config_type m_cfg;
};
Expand Down
12 changes: 6 additions & 6 deletions core/include/traccc/utils/seed_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct seed_generator {
seed_generator(const detector_t& det,
const std::array<scalar, e_bound_size>& stddevs,
const std::size_t sd = 0)
: m_detector(std::make_unique<detector_t>(det)), m_stddevs(stddevs) {
: m_detector(det), m_stddevs(stddevs) {
generator.seed(sd);
}

Expand All @@ -49,7 +49,7 @@ struct seed_generator {
const free_track_parameters& free_param) {

// Get bound parameter
const detray::surface<detector_t> sf{*m_detector, surface_link};
const detray::surface<detector_t> sf{m_detector, surface_link};

const cxt_t ctx{};
auto bound_vec = sf.free_to_bound_vector(ctx, free_param.vector());
Expand All @@ -68,10 +68,10 @@ struct seed_generator {
detray::pointwise_material_interactor<transform3_type>;

intersection_type sfi;
sfi.sf_desc = m_detector->surface(surface_link);
sfi.sf_desc = m_detector.surface(surface_link);
sf.template visit_mask<detray::intersection_update>(
detray::detail::ray<transform3_type>(free_param.vector()), sfi,
m_detector->transform_store());
m_detector.transform_store());

// Apply interactor
typename interactor_type::state interactor_state;
Expand Down Expand Up @@ -100,8 +100,8 @@ struct seed_generator {
std::random_device rd{};
std::mt19937 generator{rd()};

/// Detector objects
std::unique_ptr<detector_t> m_detector;
// Detector object
const detector_t& m_detector;
/// Standard deviations for parameter smearing
std::array<scalar, e_bound_size> m_stddevs;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

namespace traccc::device {

template <typename propagator_t, typename config_t>
template <typename propagator_t, typename bfield_t, typename config_t>
TRACCC_DEVICE inline void propagate_to_next_surface(
std::size_t globalIndex, const config_t cfg,
typename propagator_t::detector_type::view_type det_data,
bfield_t field_data,
vecmem::data::jagged_vector_view<typename propagator_t::intersection_type>
nav_candidates_buffer,
bound_track_parameters_collection_types::const_view in_params_view,
Expand Down Expand Up @@ -60,8 +61,7 @@ TRACCC_DEVICE inline void propagate_to_next_surface(

// Create propagator state
typename propagator_t::state propagation(
in_par, det.get_bfield(), det,
std::move(nav_candidates.at(globalIndex)));
in_par, field_data, det, std::move(nav_candidates.at(globalIndex)));
propagation._stepping
.template set_constraint<detray::step::constraint::e_accuracy>(
cfg.constrained_step_size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include "traccc/edm/measurement.hpp"
#include "traccc/edm/track_parameters.hpp"

// Covfie include(s).
#include <covfie/core/field.hpp>

namespace traccc::device {

/// Function for propagating the kalman-updated tracks to the next surface
Expand All @@ -35,10 +38,11 @@ namespace traccc::device {
/// @param[out] tips_view Tip link container for the current step
/// @param[out] n_out_params The number of output parameters
///
template <typename propagator_t, typename config_t>
template <typename propagator_t, typename bfield_t, typename config_t>
TRACCC_DEVICE inline void propagate_to_next_surface(
std::size_t globalIndex, const config_t cfg,
typename propagator_t::detector_type::view_type det_data,
bfield_t field_data,
vecmem::data::jagged_vector_view<typename propagator_t::intersection_type>
nav_candidates_buffer,
bound_track_parameters_collection_types::const_view in_params_view,
Expand Down
1 change: 1 addition & 0 deletions device/common/include/traccc/fitting/device/fit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace traccc::device {
template <typename fitter_t, typename detector_view_t>
TRACCC_HOST_DEVICE inline void fit(
std::size_t globalIndex, detector_view_t det_data,
const typename fitter_t::bfield_type field_data,
const typename fitter_t::config_type cfg,
vecmem::data::jagged_vector_view<typename fitter_t::intersection_type>
nav_candidates_buffer,
Expand Down
3 changes: 2 additions & 1 deletion device/common/include/traccc/fitting/device/impl/fit.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace traccc::device {
template <typename fitter_t, typename detector_view_t>
TRACCC_HOST_DEVICE inline void fit(
std::size_t globalIndex, detector_view_t det_data,
const typename fitter_t::bfield_type field_data,
const typename fitter_t::config_type cfg,
vecmem::data::jagged_vector_view<typename fitter_t::intersection_type>
nav_candidates_buffer,
Expand All @@ -28,7 +29,7 @@ TRACCC_HOST_DEVICE inline void fit(

track_state_container_types::device track_states(track_states_view);

fitter_t fitter(det, cfg);
fitter_t fitter(det, field_data, cfg);

if (globalIndex >= track_states.size()) {
return;
Expand Down
5 changes: 5 additions & 0 deletions device/cuda/include/traccc/cuda/finding/finding_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ template <typename stepper_t, typename navigator_t>
class finding_algorithm
: public algorithm<track_candidate_container_types::buffer(
const typename navigator_t::detector_type::view_type&,
const typename stepper_t::magnetic_field_type&,
const vecmem::data::jagged_vector_view<
typename navigator_t::intersection_type>&,
const typename measurement_collection_types::view&,
Expand All @@ -49,6 +50,9 @@ class finding_algorithm
/// Detector type
using detector_type = typename navigator_t::detector_type;

/// Field type
using bfield_type = typename stepper_t::magnetic_field_type;

/// Actor types
using interactor = detray::pointwise_material_interactor<transform3_type>;

Expand Down Expand Up @@ -86,6 +90,7 @@ class finding_algorithm
/// @param seeds Input seeds
track_candidate_container_types::buffer operator()(
const typename detector_type::view_type& det_view,
const bfield_type& field_view,
const vecmem::data::jagged_vector_view<
typename navigator_t::intersection_type>& navigation_buffer,
const typename measurement_collection_types::view& measurements,
Expand Down
2 changes: 2 additions & 0 deletions device/cuda/include/traccc/cuda/fitting/fitting_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ template <typename fitter_t>
class fitting_algorithm
: public algorithm<track_state_container_types::buffer(
const typename fitter_t::detector_type::view_type&,
const typename fitter_t::bfield_type&,
const vecmem::data::jagged_vector_view<
typename fitter_t::intersection_type>&,
const typename track_candidate_container_types::const_view&)> {
Expand All @@ -46,6 +47,7 @@ class fitting_algorithm
/// Run the algorithm
track_state_container_types::buffer operator()(
const typename fitter_t::detector_type::view_type& det_view,
const typename fitter_t::bfield_type& field_view,
const vecmem::data::jagged_vector_view<
typename fitter_t::intersection_type>& navigation_buffer,
const typename track_candidate_container_types::const_view&
Expand Down
22 changes: 13 additions & 9 deletions device/cuda/src/finding/finding_algorithm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,11 @@ __global__ void find_tracks(
}

/// CUDA kernel for running @c traccc::device::propagate_to_next_surface
template <typename propagator_t, typename config_t>
template <typename propagator_t, typename bfield_t, typename config_t>
__global__ void propagate_to_next_surface(
const config_t cfg,
typename propagator_t::detector_type::view_type det_data,
bfield_t field_data,
vecmem::data::jagged_vector_view<typename propagator_t::intersection_type>
nav_candidates_buffer,
bound_track_parameters_collection_types::const_view in_params_view,
Expand All @@ -133,10 +134,10 @@ __global__ void propagate_to_next_surface(

int gid = threadIdx.x + blockIdx.x * blockDim.x;

device::propagate_to_next_surface<propagator_t, config_t>(
gid, cfg, det_data, nav_candidates_buffer, in_params_view, links_view,
step, n_candidates, out_params_view, param_to_link_view, tips_view,
n_out_params);
device::propagate_to_next_surface<propagator_t, bfield_t, config_t>(
gid, cfg, det_data, field_data, nav_candidates_buffer, in_params_view,
links_view, step, n_candidates, out_params_view, param_to_link_view,
tips_view, n_out_params);
}

/// CUDA kernel for running @c traccc::device::build_tracks
Expand Down Expand Up @@ -174,6 +175,7 @@ template <typename stepper_t, typename navigator_t>
track_candidate_container_types::buffer
finding_algorithm<stepper_t, navigator_t>::operator()(
const typename detector_type::view_type& det_view,
const bfield_type& field_view,
const vecmem::data::jagged_vector_view<
typename navigator_t::intersection_type>& navigation_buffer,
const typename measurement_collection_types::view& measurements,
Expand Down Expand Up @@ -400,11 +402,13 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
if (global_counter_host.n_candidates > 0) {
nBlocks =
(global_counter_host.n_candidates + nThreads - 1) / nThreads;
kernels::propagate_to_next_surface<propagator_type, config_type>
kernels::propagate_to_next_surface<propagator_type, bfield_type,
config_type>
<<<nBlocks, nThreads>>>(
m_cfg, det_view, navigation_buffer, updated_params_buffer,
link_map[step], step, (*global_counter_device).n_candidates,
out_params_buffer, param_to_link_map[step], tips_map[step],
m_cfg, det_view, field_view, navigation_buffer,
updated_params_buffer, link_map[step], step,
(*global_counter_device).n_candidates, out_params_buffer,
param_to_link_map[step], tips_map[step],
(*global_counter_device).n_out_params);
CUDA_ERROR_CHECK(cudaGetLastError());
CUDA_ERROR_CHECK(cudaDeviceSynchronize());
Expand Down
30 changes: 16 additions & 14 deletions device/cuda/src/fitting/fitting_algorithm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"

// detray include(s).
#include "detray/detectors/bfield.hpp"
#include "detray/detectors/telescope_metadata.hpp"
#include "detray/detectors/toy_metadata.hpp"
#include "detray/masks/unbounded.hpp"
Expand All @@ -26,15 +27,16 @@ namespace kernels {

template <typename fitter_t, typename detector_view_t>
__global__ void fit(
detector_view_t det_data, const typename fitter_t::config_type cfg,
detector_view_t det_data, const typename fitter_t::bfield_type field_data,
const typename fitter_t::config_type cfg,
vecmem::data::jagged_vector_view<typename fitter_t::intersection_type>
nav_candidates_buffer,
track_candidate_container_types::const_view track_candidates_view,
track_state_container_types::view track_states_view) {

int gid = threadIdx.x + blockIdx.x * blockDim.x;

device::fit<fitter_t>(gid, det_data, cfg, nav_candidates_buffer,
device::fit<fitter_t>(gid, det_data, field_data, cfg, nav_candidates_buffer,
track_candidates_view, track_states_view);
}

Expand All @@ -56,6 +58,7 @@ fitting_algorithm<fitter_t>::fitting_algorithm(
template <typename fitter_t>
track_state_container_types::buffer fitting_algorithm<fitter_t>::operator()(
const typename fitter_t::detector_type::view_type& det_view,
const typename fitter_t::bfield_type& field_view,
const vecmem::data::jagged_vector_view<
typename fitter_t::intersection_type>& navigation_buffer,
const typename track_candidate_container_types::const_view&
Expand Down Expand Up @@ -86,9 +89,9 @@ track_state_container_types::buffer fitting_algorithm<fitter_t>::operator()(
const unsigned int nBlocks = (n_tracks + nThreads - 1) / nThreads;

// Run the track fitting
kernels::fit<fitter_t>
<<<nBlocks, nThreads>>>(det_view, m_cfg, navigation_buffer,
track_candidates_view, track_states_buffer);
kernels::fit<fitter_t><<<nBlocks, nThreads>>>(
det_view, field_view, m_cfg, navigation_buffer,
track_candidates_view, track_states_buffer);
CUDA_ERROR_CHECK(cudaGetLastError());
CUDA_ERROR_CHECK(cudaDeviceSynchronize());
}
Expand All @@ -97,21 +100,20 @@ track_state_container_types::buffer fitting_algorithm<fitter_t>::operator()(

// Explicit template instantiation
using toy_detector_type =
detray::detector<detray::toy_metadata<>, covfie::field_view,
detray::device_container_types>;
using toy_stepper_type = detray::rk_stepper<
covfie::field<toy_detector_type::bfield_backend_type>::view_t, transform3,
detray::constrained_step<>>;
detray::detector<detray::toy_metadata, detray::device_container_types>;
using toy_stepper_type =
detray::rk_stepper<covfie::field_view<detray::bfield::const_bknd_t>,
transform3, detray::constrained_step<>>;
using toy_navigator_type = detray::navigator<const toy_detector_type>;
using toy_fitter_type = kalman_fitter<toy_stepper_type, toy_navigator_type>;
template class fitting_algorithm<toy_fitter_type>;

using device_detector_type =
detray::detector<detray::telescope_metadata<detray::rectangle2D<>>,
covfie::field_view, detray::device_container_types>;
using rk_stepper_type = detray::rk_stepper<
covfie::field<device_detector_type::bfield_backend_type>::view_t,
transform3, detray::constrained_step<>>;
detray::device_container_types>;
using rk_stepper_type =
detray::rk_stepper<covfie::field_view<detray::bfield::const_bknd_t>,
transform3, detray::constrained_step<>>;
using device_navigator_type = detray::navigator<const device_detector_type>;
using device_fitter_type =
kalman_fitter<rk_stepper_type, device_navigator_type>;
Expand Down
Loading

0 comments on commit 4eaa933

Please sign in to comment.