diff --git a/pysr/julia_extensions.py b/pysr/julia_extensions.py index 950c292e..22036cb7 100644 --- a/pysr/julia_extensions.py +++ b/pysr/julia_extensions.py @@ -2,10 +2,20 @@ from typing import Literal +from .julia_helpers import KNOWN_CLUSTERMANAGER_BACKENDS from .julia_import import Pkg, jl from .julia_registry_helpers import try_with_registry_fallback from .logger_specs import AbstractLoggerSpec, TensorBoardLoggerSpec +PACKAGE_UUIDS = { + "LoopVectorization": "bdcacae8-1622-11e9-2a5c-532679323890", + "Bumper": "8ce10254-0962-460f-a3d8-1f77fea1446e", + "Zygote": "e88e6eb3-aa80-5325-afca-941959d7151f", + "SlurmClusterManager": "c82cd089-7bf7-41d7-976b-6b5d413cbe0a", + "ClusterManagers": "34f1f09b-3a8b-5176-ab39-66d58a4d544e", + "TensorBoardLogger": "899adc3e-224a-11e9-021f-63837185c80f", +} + def load_required_packages( *, @@ -16,26 +26,24 @@ def load_required_packages( logger_spec: AbstractLoggerSpec | None = None, ): if turbo: - load_package("LoopVectorization", "bdcacae8-1622-11e9-2a5c-532679323890") + load_package("LoopVectorization") if bumper: - load_package("Bumper", "8ce10254-0962-460f-a3d8-1f77fea1446e") + load_package("Bumper") if autodiff_backend is not None: - load_package("Zygote", "e88e6eb3-aa80-5325-afca-941959d7151f") + load_package("Zygote") if cluster_manager is not None: - load_package("ClusterManagers", "34f1f09b-3a8b-5176-ab39-66d58a4d544e") + if cluster_manager == "slurm_native": + load_package("SlurmClusterManager") + elif cluster_manager in KNOWN_CLUSTERMANAGER_BACKENDS: + load_package("ClusterManagers") if isinstance(logger_spec, TensorBoardLoggerSpec): - load_package("TensorBoardLogger", "899adc3e-224a-11e9-021f-63837185c80f") + load_package("TensorBoardLogger") def load_all_packages(): """Install and load all Julia extensions available to PySR.""" - load_required_packages( - turbo=True, - bumper=True, - autodiff_backend="Zygote", - cluster_manager="slurm", - logger_spec=TensorBoardLoggerSpec(log_dir="logs"), - ) + for package_name, uuid_s in PACKAGE_UUIDS.items(): + load_package(package_name, uuid_s) # TODO: Refactor this file so we can install all packages at once using `juliapkg`, @@ -46,7 +54,8 @@ def isinstalled(uuid_s: str): return jl.haskey(Pkg.dependencies(), jl.Base.UUID(uuid_s)) -def load_package(package_name: str, uuid_s: str) -> None: +def load_package(package_name: str, uuid_s: str | None = None) -> None: + uuid_s = uuid_s or PACKAGE_UUIDS[package_name] if not isinstalled(uuid_s): def _add_package(): diff --git a/pysr/julia_helpers.py b/pysr/julia_helpers.py index ef82be90..5842496e 100644 --- a/pysr/julia_helpers.py +++ b/pysr/julia_helpers.py @@ -29,9 +29,21 @@ def _escape_filename(filename): return str_repr -def _load_cluster_manager(cluster_manager: str): - jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}") - return jl.seval(f"addprocs_{cluster_manager}") +KNOWN_CLUSTERMANAGER_BACKENDS = ["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"] + + +def load_cluster_manager(cluster_manager: str) -> AnyValue: + if cluster_manager == "slurm_native": + jl.seval("using SlurmClusterManager: SlurmManager") + # TODO: Is this the right way to do this? + jl.seval("addprocs_slurm_native(; _...) = addprocs(SlurmManager())") + return jl.addprocs_slurm_native + elif cluster_manager in KNOWN_CLUSTERMANAGER_BACKENDS: + jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}") + return jl.seval(f"addprocs_{cluster_manager}") + else: + # Assume it's a function + return jl.seval(cluster_manager) def jl_array(x, dtype=None): diff --git a/pysr/sr.py b/pysr/sr.py index 00209378..8da6cfee 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -41,11 +41,11 @@ from .julia_extensions import load_required_packages from .julia_helpers import ( _escape_filename, - _load_cluster_manager, jl_array, jl_deserialize, jl_is_function, jl_serialize, + load_cluster_manager, ) from .julia_import import AnyValue, SymbolicRegression, VectorValue, jl from .logger_specs import AbstractLoggerSpec @@ -549,8 +549,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): Default is `None`. cluster_manager : str For distributed computing, this sets the job queue system. Set - to one of "slurm", "pbs", "lsf", "sge", "qrsh", "scyld", or - "htc". If set to one of these, PySR will run in distributed + to one of "slurm_native", "slurm", "pbs", "lsf", "sge", "qrsh", "scyld", + or "htc". If set to one of these, PySR will run in distributed mode, and use `procs` to figure out how many processes to launch. Default is `None`. heap_size_hint_in_bytes : int @@ -849,13 +849,11 @@ def __init__( probability_negate_constant: float = 0.00743, tournament_selection_n: int = 15, tournament_selection_p: float = 0.982, - parallelism: ( - Literal["serial", "multithreading", "multiprocessing"] | None - ) = None, + # fmt: off + parallelism: Literal["serial", "multithreading", "multiprocessing"] | None = None, procs: int | None = None, - cluster_manager: ( - Literal["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"] | None - ) = None, + cluster_manager: Literal["slurm_native", "slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"] | str | None = None, + # fmt: on heap_size_hint_in_bytes: int | None = None, batching: bool = False, batch_size: int = 50, @@ -1842,7 +1840,7 @@ def _run( raise ValueError( "To use cluster managers, you must set `parallelism='multiprocessing'`." ) - cluster_manager = _load_cluster_manager(cluster_manager) + cluster_manager = load_cluster_manager(cluster_manager) # TODO(mcranmer): These functions should be part of this class. binary_operators, unary_operators = _maybe_create_inline_operators(