diff --git a/src/rail/stages/__init__.py b/src/rail/stages/__init__.py index 93ec5f07..49ce4859 100644 --- a/src/rail/stages/__init__.py +++ b/src/rail/stages/__init__.py @@ -1,5 +1,7 @@ import rail + from rail.core import RailEnv +from rail.core.stage import RailStage from rail.estimation.estimator import CatEstimator from rail.estimation.classifier import CatClassifier, PZClassifier @@ -34,13 +36,6 @@ from rail.tools.table_tools import ColumnMapper, RowSelector, TableConverter - -def import_and_attach_all(): - """Import all the packages in the rail ecosystem and attach them to this module""" - RailEnv.import_all_packages() - RailEnv.attach_stages(rail.stages) - - __all__ = [ "CatEstimator", "CatClassifier", @@ -79,3 +74,12 @@ def import_and_attach_all(): "RowSelector", "TableConverter", ] + + +def import_and_attach_all(): + """Import all the packages in the rail ecosystem and attach them to this module""" + RailEnv.import_all_packages() + RailEnv.attach_stages(rail.stages) + for xx in RailStage.pipeline_stages: + rail.stages.__all__.append(xx) +