diff --git a/src/rail/stages/__init__.py b/src/rail/stages/__init__.py index dc215d53..49ce4859 100644 --- a/src/rail/stages/__init__.py +++ b/src/rail/stages/__init__.py @@ -1,5 +1,7 @@ import rail + from rail.core import RailEnv +from rail.core.stage import RailStage from rail.estimation.estimator import CatEstimator from rail.estimation.classifier import CatClassifier, PZClassifier @@ -78,6 +80,6 @@ def import_and_attach_all(): """Import all the packages in the rail ecosystem and attach them to this module""" RailEnv.import_all_packages() RailEnv.attach_stages(rail.stages) - for xx in dir(rail.stages): + for xx in RailStage.pipeline_stages: rail.stages.__all__.append(xx)