From 98cca40929502a29c1893e0a9e70e83080eb2eda Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 17 Sep 2023 13:30:15 +0100 Subject: [PATCH 01/17] Refactor utility functions --- pysr/sr.py | 60 +++++++-------------------------------------------- pysr/utils.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 52 deletions(-) create mode 100644 pysr/utils.py diff --git a/pysr/sr.py b/pysr/sr.py index e7383dcc3..777458014 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -33,6 +33,12 @@ init_julia, is_julia_version_greater_eq, ) +from .utils import ( + _csv_filename_to_pkl_filename, + _preprocess_julia_floats, + _safe_check_feature_names_in, + _subscriptify, +) Main = None # TODO: Rename to more descriptive name like "julia_runtime" @@ -945,10 +951,8 @@ def from_file( model : PySRRegressor The model with fitted equations. """ - if os.path.splitext(equation_file)[1] != ".pkl": - pkl_filename = _csv_filename_to_pkl_filename(equation_file) - else: - pkl_filename = equation_file + + pkl_filename = _csv_filename_to_pkl_filename(equation_file) # Try to load model from .pkl print(f"Checking if {pkl_filename} exists...") @@ -2437,51 +2441,3 @@ def run_feature_selection(X, y, select_k_features, random_state=None): clf, threshold=-np.inf, max_features=select_k_features, prefit=True ) return selector.get_support(indices=True) - - -def _csv_filename_to_pkl_filename(csv_filename) -> str: - # Assume that the csv filename is of the form "foo.csv" - assert str(csv_filename).endswith(".csv") - - dirname = str(os.path.dirname(csv_filename)) - basename = str(os.path.basename(csv_filename)) - base = str(os.path.splitext(basename)[0]) - - pkl_basename = base + ".pkl" - - return os.path.join(dirname, pkl_basename) - - -_regexp_im = re.compile(r"\b(\d+\.\d+)im\b") -_regexp_im_sci = re.compile(r"\b(\d+\.\d+)[eEfF]([+-]?\d+)im\b") -_regexp_sci = re.compile(r"\b(\d+\.\d+)[eEfF]([+-]?\d+)\b") - -_apply_regexp_im = lambda x: _regexp_im.sub(r"\1j", x) -_apply_regexp_im_sci = lambda x: _regexp_im_sci.sub(r"\1e\2j", x) -_apply_regexp_sci = lambda x: _regexp_sci.sub(r"\1e\2", x) - - -def _preprocess_julia_floats(s: str) -> str: - if isinstance(s, str): - s = _apply_regexp_im(s) - s = _apply_regexp_im_sci(s) - s = _apply_regexp_sci(s) - return s - - -def _subscriptify(i: int) -> str: - """Converts integer to subscript text form. - - For example, 123 -> "₁₂₃". - """ - return "".join([chr(0x2080 + int(c)) for c in str(i)]) - - -def _safe_check_feature_names_in(self, variable_names, generate_names=True): - """_check_feature_names_in with compat for old versions.""" - try: - return _check_feature_names_in( - self, variable_names, generate_names=generate_names - ) - except TypeError: - return _check_feature_names_in(self, variable_names) diff --git a/pysr/utils.py b/pysr/utils.py new file mode 100644 index 000000000..ca000aae7 --- /dev/null +++ b/pysr/utils.py @@ -0,0 +1,55 @@ +import os +import re + +from sklearn.utils.validation import _check_feature_names_in + + +def _csv_filename_to_pkl_filename(csv_filename: str) -> str: + if os.path.splitext(csv_filename)[1] == ".pkl": + return csv_filename + + # Assume that the csv filename is of the form "foo.csv" + assert str(csv_filename).endswith(".csv") + + dirname = str(os.path.dirname(csv_filename)) + basename = str(os.path.basename(csv_filename)) + base = str(os.path.splitext(basename)[0]) + + pkl_basename = base + ".pkl" + + return os.path.join(dirname, pkl_basename) + + +_regexp_im = re.compile(r"\b(\d+\.\d+)im\b") +_regexp_im_sci = re.compile(r"\b(\d+\.\d+)[eEfF]([+-]?\d+)im\b") +_regexp_sci = re.compile(r"\b(\d+\.\d+)[eEfF]([+-]?\d+)\b") + +_apply_regexp_im = lambda x: _regexp_im.sub(r"\1j", x) +_apply_regexp_im_sci = lambda x: _regexp_im_sci.sub(r"\1e\2j", x) +_apply_regexp_sci = lambda x: _regexp_sci.sub(r"\1e\2", x) + + +def _preprocess_julia_floats(s: str) -> str: + if isinstance(s, str): + s = _apply_regexp_im(s) + s = _apply_regexp_im_sci(s) + s = _apply_regexp_sci(s) + return s + + +def _safe_check_feature_names_in(self, variable_names, generate_names=True): + """_check_feature_names_in with compat for old versions.""" + try: + return _check_feature_names_in( + self, variable_names, generate_names=generate_names + ) + except TypeError: + return _check_feature_names_in(self, variable_names) + + +def _subscriptify(i: int) -> str: + """Converts integer to subscript text form. + + For example, 123 -> "₁₂₃". + """ + return "".join([chr(0x2080 + int(c)) for c in str(i)]) From 47136078f50b5598e7530968cdc164bd7e156993 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 17 Sep 2023 13:37:58 +0100 Subject: [PATCH 02/17] Move denoising functionality to separate file --- pysr/denoising.py | 35 +++++++++++++++++++++++++++++++++++ pysr/sr.py | 31 ++++--------------------------- 2 files changed, 39 insertions(+), 27 deletions(-) create mode 100644 pysr/denoising.py diff --git a/pysr/denoising.py b/pysr/denoising.py new file mode 100644 index 000000000..b65484529 --- /dev/null +++ b/pysr/denoising.py @@ -0,0 +1,35 @@ +"""Functions for denoising data during preprocessing.""" +import numpy as np + + +def denoise(X, y, Xresampled=None, random_state=None): + """Denoise the dataset using a Gaussian process.""" + from sklearn.gaussian_process import GaussianProcessRegressor + from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel + + gp_kernel = RBF(np.ones(X.shape[1])) + WhiteKernel(1e-1) + ConstantKernel() + gpr = GaussianProcessRegressor( + kernel=gp_kernel, n_restarts_optimizer=50, random_state=random_state + ) + gpr.fit(X, y) + + if Xresampled is not None: + return Xresampled, gpr.predict(Xresampled) + + return X, gpr.predict(X) + + +def multi_denoise(X, y, Xresampled=None, random_state=None): + """Perform `denoise` along each column of `y` independently.""" + y = np.stack( + [ + denoise(X, y[:, i], Xresampled=Xresampled, random_state=random_state)[1] + for i in range(y.shape[1]) + ], + axis=1, + ) + + if Xresampled is not None: + return Xresampled, y + + return X, y diff --git a/pysr/sr.py b/pysr/sr.py index 777458014..eeb8be22d 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -18,6 +18,7 @@ from sklearn.utils import check_array, check_consistent_length, check_random_state from sklearn.utils.validation import _check_feature_names_in, check_is_fitted +from .denoising import denoise, multi_denoise from .deprecated import make_deprecated_kwargs_for_pysr_regressor from .export_jax import sympy2jax from .export_latex import sympy2latex, sympy2latextable, sympy2multilatextable @@ -1506,19 +1507,11 @@ def _pre_transform_training_data( # Denoising transformation if self.denoise: if self.nout_ > 1: - y = np.stack( - [ - _denoise( - X, y[:, i], Xresampled=Xresampled, random_state=random_state - )[1] - for i in range(self.nout_) - ], - axis=1, + X, y = multi_denoise( + X, y, Xresampled=Xresampled, random_state=random_state ) - if Xresampled is not None: - X = Xresampled else: - X, y = _denoise(X, y, Xresampled=Xresampled, random_state=random_state) + X, y = denoise(X, y, Xresampled=Xresampled, random_state=random_state) return X, y, variable_names, X_units, y_units @@ -2394,22 +2387,6 @@ def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int: return chosen_idx -def _denoise(X, y, Xresampled=None, random_state=None): - """Denoise the dataset using a Gaussian process.""" - from sklearn.gaussian_process import GaussianProcessRegressor - from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel - - gp_kernel = RBF(np.ones(X.shape[1])) + WhiteKernel(1e-1) + ConstantKernel() - gpr = GaussianProcessRegressor( - kernel=gp_kernel, n_restarts_optimizer=50, random_state=random_state - ) - gpr.fit(X, y) - if Xresampled is not None: - return Xresampled, gpr.predict(Xresampled) - - return X, gpr.predict(X) - - # Function has not been removed only due to usage in module tests def _handle_feature_selection(X, select_k_features, y, variable_names): if select_k_features is not None: From 3ae241a1baa2a8c8e0a149ed7633e43d2f4e71f6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 17 Sep 2023 13:42:11 +0100 Subject: [PATCH 03/17] Move feature selection functionality to separate file --- pysr/feature_selection.py | 35 +++++++++++++++++++++++++++++++++++ pysr/sr.py | 34 +--------------------------------- pysr/test/test.py | 11 +++-------- 3 files changed, 39 insertions(+), 41 deletions(-) create mode 100644 pysr/feature_selection.py diff --git a/pysr/feature_selection.py b/pysr/feature_selection.py new file mode 100644 index 000000000..a483e7e7f --- /dev/null +++ b/pysr/feature_selection.py @@ -0,0 +1,35 @@ +"""Functions for doing feature selection during preprocessing.""" +import numpy as np + + +def run_feature_selection(X, y, select_k_features, random_state=None) -> np.ndarray: + """ + Find most important features. + + Uses a gradient boosting tree regressor as a proxy for finding + the k most important features in X, returning indices for those + features as output. + """ + from sklearn.ensemble import RandomForestRegressor + from sklearn.feature_selection import SelectFromModel + + clf = RandomForestRegressor( + n_estimators=100, max_depth=3, random_state=random_state + ) + clf.fit(X, y) + selector = SelectFromModel( + clf, threshold=-np.inf, max_features=select_k_features, prefit=True + ) + return selector.get_support(indices=True) + + +# Function has not been removed only due to usage in module tests +def _handle_feature_selection(X, select_k_features, y, variable_names): + if select_k_features is not None: + selection = run_feature_selection(X, y, select_k_features) + print(f"Using features {[variable_names[i] for i in selection]}") + X = X[:, selection] + else: + selection = None + + return X, selection diff --git a/pysr/sr.py b/pysr/sr.py index eeb8be22d..e38800a98 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -25,6 +25,7 @@ from .export_numpy import sympy2numpy from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2sympy from .export_torch import sympy2torch +from .feature_selection import run_feature_selection from .julia_helpers import ( _escape_filename, _load_backend, @@ -2385,36 +2386,3 @@ def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int: f"{model_selection} is not a valid model selection strategy." ) return chosen_idx - - -# Function has not been removed only due to usage in module tests -def _handle_feature_selection(X, select_k_features, y, variable_names): - if select_k_features is not None: - selection = run_feature_selection(X, y, select_k_features) - print(f"Using features {[variable_names[i] for i in selection]}") - X = X[:, selection] - - else: - selection = None - return X, selection - - -def run_feature_selection(X, y, select_k_features, random_state=None): - """ - Find most important features. - - Uses a gradient boosting tree regressor as a proxy for finding - the k most important features in X, returning indices for those - features as output. - """ - from sklearn.ensemble import RandomForestRegressor - from sklearn.feature_selection import SelectFromModel - - clf = RandomForestRegressor( - n_estimators=100, max_depth=3, random_state=random_state - ) - clf.fit(X, y) - selector = SelectFromModel( - clf, threshold=-np.inf, max_features=select_k_features, prefit=True - ) - return selector.get_support(indices=True) diff --git a/pysr/test/test.py b/pysr/test/test.py index 31bfcfa4c..120f34d08 100644 --- a/pysr/test/test.py +++ b/pysr/test/test.py @@ -14,14 +14,9 @@ from .. import PySRRegressor, julia_helpers from ..export_latex import sympy2latex -from ..sr import ( - _check_assertions, - _csv_filename_to_pkl_filename, - _handle_feature_selection, - _process_constraints, - idx_model_selection, - run_feature_selection, -) +from ..feature_selection import _handle_feature_selection, run_feature_selection +from ..sr import _check_assertions, _process_constraints, idx_model_selection +from ..utils import _csv_filename_to_pkl_filename DEFAULT_PARAMS = inspect.signature(PySRRegressor.__init__).parameters DEFAULT_NITERATIONS = DEFAULT_PARAMS["niterations"].default From ff2ef42f8e2cb11e59b6bd26b5fa4f88699aa549 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 17 Sep 2023 14:40:27 +0100 Subject: [PATCH 04/17] Mypy compatibility --- .github/workflows/CI.yml | 26 ++++++++++++++++++++++++++ mypy.ini | 8 ++++++++ pysr/export_latex.py | 24 +++++++++++++----------- pysr/export_sympy.py | 4 ++-- pysr/feature_selection.py | 2 +- pysr/sr.py | 11 ++++++----- 6 files changed, 56 insertions(+), 19 deletions(-) create mode 100644 mypy.ini diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b863a4d8f..91ca092a5 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -143,3 +143,29 @@ jobs: run: | pip install coveralls coveralls --finish + + types: + name: Check types + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + strategy: + matrix: + python-version: ['3.10'] + + steps: + - uses: actions/checkout@v3 + - name: "Set up Python" + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: pip + - name: "Install PySR and all dependencies" + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install mypy jax jaxlib torch + python setup.py install + - name: "Run mypy" + run: mypy --install-types --non-interactive pysr diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..850edc545 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,8 @@ +[mypy] +warn_return_any = True + +[mypy-sklearn.*] +ignore_missing_imports = True + +[mypy-julia.*] +ignore_missing_imports = True diff --git a/pysr/export_latex.py b/pysr/export_latex.py index bb655d658..0316f872a 100644 --- a/pysr/export_latex.py +++ b/pysr/export_latex.py @@ -1,5 +1,5 @@ """Functions to help export PySR equations to LaTeX.""" -from typing import List +from typing import List, Optional, Tuple import pandas as pd import sympy @@ -19,14 +19,16 @@ def _print_Float(self, expr): return super()._print_Float(reduced_float) -def sympy2latex(expr, prec=3, full_prec=True, **settings): +def sympy2latex(expr, prec=3, full_prec=True, **settings) -> str: """Convert sympy expression to LaTeX with custom precision.""" settings["full_prec"] = full_prec printer = PreciseLatexPrinter(settings=settings, prec=prec) return printer.doprint(expr) -def generate_table_environment(columns=["equation", "complexity", "loss"]): +def generate_table_environment( + columns: List[str] = ["equation", "complexity", "loss"] +) -> Tuple[str, str]: margins = "c" * len(columns) column_map = { "complexity": "Complexity", @@ -58,12 +60,12 @@ def generate_table_environment(columns=["equation", "complexity", "loss"]): def sympy2latextable( equations: pd.DataFrame, - indices: List[int] = None, + indices: Optional[List[int]] = None, precision: int = 3, - columns=["equation", "complexity", "loss", "score"], + columns: List[str] = ["equation", "complexity", "loss", "score"], max_equation_length: int = 50, output_variable_name: str = "y", -): +) -> str: """Generate a booktabs-style LaTeX table for a single set of equations.""" assert isinstance(equations, pd.DataFrame) @@ -71,7 +73,7 @@ def sympy2latextable( latex_table_content = [] if indices is None: - indices = range(len(equations)) + indices = list(equations.index) for i in indices: latex_equation = sympy2latex( @@ -126,11 +128,11 @@ def sympy2latextable( def sympy2multilatextable( equations: List[pd.DataFrame], - indices: List[List[int]] = None, + indices: Optional[List[List[int]]] = None, precision: int = 3, - columns=["equation", "complexity", "loss", "score"], - output_variable_names: str = None, -): + columns: List[str] = ["equation", "complexity", "loss", "score"], + output_variable_names: Optional[List[str]] = None, +) -> str: """Generate multiple latex tables for a list of equation sets.""" # TODO: Let user specify custom output variable diff --git a/pysr/export_sympy.py b/pysr/export_sympy.py index 37bd99da0..81142f481 100644 --- a/pysr/export_sympy.py +++ b/pysr/export_sympy.py @@ -51,14 +51,14 @@ def create_sympy_symbols( - feature_names_in: Optional[List[str]] = None, + feature_names_in: List[str], ) -> List[sympy.Symbol]: return [sympy.Symbol(variable) for variable in feature_names_in] def pysr2sympy( equation: str, *, extra_sympy_mappings: Optional[Dict[str, Callable]] = None -) -> sympy.Expr: +): local_sympy_mappings = { **(extra_sympy_mappings if extra_sympy_mappings else {}), **sympy_mappings, diff --git a/pysr/feature_selection.py b/pysr/feature_selection.py index a483e7e7f..a6ebf0390 100644 --- a/pysr/feature_selection.py +++ b/pysr/feature_selection.py @@ -2,7 +2,7 @@ import numpy as np -def run_feature_selection(X, y, select_k_features, random_state=None) -> np.ndarray: +def run_feature_selection(X, y, select_k_features, random_state=None): """ Find most important features. diff --git a/pysr/sr.py b/pysr/sr.py index e38800a98..71df6cacc 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -11,6 +11,7 @@ from io import StringIO from multiprocessing import cpu_count from pathlib import Path +from typing import List, Optional import numpy as np import pandas as pd @@ -1781,10 +1782,10 @@ def fit( y, Xresampled=None, weights=None, - variable_names=None, - X_units=None, - y_units=None, - ): + variable_names: Optional[List[str]] = None, + X_units: Optional[List[str]] = None, + y_units: Optional[List[str]] = None, + ) -> "PySRRegressor": """ Search for equations to fit the dataset and store them in `self.equations_`. @@ -2371,7 +2372,7 @@ def latex_table( return "\n".join(preamble_string + [table_string]) -def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int: +def idx_model_selection(equations: pd.DataFrame, model_selection: str): """Select an expression and return its index.""" if model_selection == "accuracy": chosen_idx = equations["loss"].idxmin() From 135a4641caa97a26798143c162b374ec3e24388b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 17 Sep 2023 15:01:52 +0100 Subject: [PATCH 05/17] Move all deprecated functions to deprecated.py --- pysr/__init__.py | 3 ++- pysr/deprecated.py | 54 ++++++++++++++++++++++++++++++++++++++++ pysr/feynman_problems.py | 2 +- pysr/sr.py | 42 ------------------------------- 4 files changed, 57 insertions(+), 44 deletions(-) diff --git a/pysr/__init__.py b/pysr/__init__.py index 5f2200356..99c6a9742 100644 --- a/pysr/__init__.py +++ b/pysr/__init__.py @@ -1,9 +1,10 @@ from . import sklearn_monkeypatch +from .deprecated import best, best_callable, best_row, best_tex, pysr from .export_jax import sympy2jax from .export_torch import sympy2torch from .feynman_problems import FeynmanProblem, Problem from .julia_helpers import install -from .sr import PySRRegressor, best, best_callable, best_row, best_tex, pysr +from .sr import PySRRegressor from .version import __version__ __all__ = [ diff --git a/pysr/deprecated.py b/pysr/deprecated.py index ea5922729..ecfeb45e5 100644 --- a/pysr/deprecated.py +++ b/pysr/deprecated.py @@ -1,4 +1,58 @@ """Various functions to deprecate features.""" +import warnings + + +def pysr(X, y, weights=None, **kwargs): # pragma: no cover + from .sr import PySRRegressor + + warnings.warn( + "Calling `pysr` is deprecated. " + "Please use `model = PySRRegressor(**params); " + "model.fit(X, y)` going forward.", + FutureWarning, + ) + model = PySRRegressor(**kwargs) + model.fit(X, y, weights=weights) + return model.equations_ + + +def best(*args, **kwargs): # pragma: no cover + raise NotImplementedError( + "`best` has been deprecated. " + "Please use the `PySRRegressor` interface. " + "After fitting, you can return `.sympy()` " + "to get the sympy representation " + "of the best equation." + ) + + +def best_row(*args, **kwargs): # pragma: no cover + raise NotImplementedError( + "`best_row` has been deprecated. " + "Please use the `PySRRegressor` interface. " + "After fitting, you can run `print(model)` to view the best equation, " + "or " + "`model.get_best()` to return the best equation's " + "row in `model.equations_`." + ) + + +def best_tex(*args, **kwargs): # pragma: no cover + raise NotImplementedError( + "`best_tex` has been deprecated. " + "Please use the `PySRRegressor` interface. " + "After fitting, you can return `.latex()` to " + "get the sympy representation " + "of the best equation." + ) + + +def best_callable(*args, **kwargs): # pragma: no cover + raise NotImplementedError( + "`best_callable` has been deprecated. Please use the `PySRRegressor` " + "interface. After fitting, you can use " + "`.predict(X)` to use the best callable." + ) def make_deprecated_kwargs_for_pysr_regressor(): diff --git a/pysr/feynman_problems.py b/pysr/feynman_problems.py index a264a901b..b64b41397 100644 --- a/pysr/feynman_problems.py +++ b/pysr/feynman_problems.py @@ -4,7 +4,7 @@ import numpy as np -from .sr import best, pysr +from .deprecated import best, pysr PKG_DIR = Path(__file__).parents[1] FEYNMAN_DATASET = PKG_DIR / "datasets" / "FeynmanEquations.csv" diff --git a/pysr/sr.py b/pysr/sr.py index 71df6cacc..d824b52c2 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -48,17 +48,6 @@ already_ran = False -def pysr(X, y, weights=None, **kwargs): # pragma: no cover - warnings.warn( - "Calling `pysr` is deprecated. " - "Please use `model = PySRRegressor(**params); model.fit(X, y)` going forward.", - FutureWarning, - ) - model = PySRRegressor(**kwargs) - model.fit(X, y, weights=weights) - return model.equations_ - - def _process_constraints(binary_operators, unary_operators, constraints): constraints = constraints.copy() for op in unary_operators: @@ -181,37 +170,6 @@ def _check_assertions( ) -def best(*args, **kwargs): # pragma: no cover - raise NotImplementedError( - "`best` has been deprecated. Please use the `PySRRegressor` interface. " - "After fitting, you can return `.sympy()` to get the sympy representation " - "of the best equation." - ) - - -def best_row(*args, **kwargs): # pragma: no cover - raise NotImplementedError( - "`best_row` has been deprecated. Please use the `PySRRegressor` interface. " - "After fitting, you can run `print(model)` to view the best equation, or " - "`model.get_best()` to return the best equation's row in `model.equations_`." - ) - - -def best_tex(*args, **kwargs): # pragma: no cover - raise NotImplementedError( - "`best_tex` has been deprecated. Please use the `PySRRegressor` interface. " - "After fitting, you can return `.latex()` to get the sympy representation " - "of the best equation." - ) - - -def best_callable(*args, **kwargs): # pragma: no cover - raise NotImplementedError( - "`best_callable` has been deprecated. Please use the `PySRRegressor` " - "interface. After fitting, you can use `.predict(X)` to use the best callable." - ) - - # Class validation constants VALID_OPTIMIZER_ALGORITHMS = ["NelderMead", "BFGS"] From 6c92e1cdff74189664f1e46ef9a9b4d07981bb9b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 17 Sep 2023 16:13:11 +0100 Subject: [PATCH 06/17] Store `sr_options_` and rename state to `sr_state_` --- pysr/sr.py | 40 ++++++++++++++++++++++++++++------------ pysr/test/test.py | 4 ++-- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/pysr/sr.py b/pysr/sr.py index d824b52c2..ef99693d6 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -603,8 +603,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): Path to the temporary equations directory. equation_file_ : str Output equation file name produced by the julia backend. - raw_julia_state_ : tuple[list[PyCall.jlwrap], PyCall.jlwrap] + sr_state_ : tuple[list[PyCall.jlwrap], PyCall.jlwrap] The state for the julia SymbolicRegression.jl backend post fitting. + sr_options_ : PyCall.jlwrap + The options used by `SymbolicRegression.jl`, created during + a call to `.fit`. You may use this to manually call functions + in `SymbolicRegression` which take an `::Options` argument. equation_file_contents_ : list[pandas.DataFrame] Contents of the equation file output by the Julia backend. show_pickle_warnings_ : bool @@ -1031,7 +1035,7 @@ def __getstate__(self): serialization. Thus, for `PySRRegressor` to support pickle serialization, the - `raw_julia_state_` attribute must be hidden from pickle. This will + `sr_state_` attribute must be hidden from pickle. This will prevent the `warm_start` of any model that is loaded via `pickle.loads()`, but does allow all other attributes of a fitted `PySRRegressor` estimator to be serialized. Note: Jax and Torch format equations are also removed @@ -1041,9 +1045,9 @@ def __getstate__(self): show_pickle_warning = not ( "show_pickle_warnings_" in state and not state["show_pickle_warnings_"] ) - if "raw_julia_state_" in state and show_pickle_warning: + if ("sr_state_" in state or "sr_options_" in state) and show_pickle_warning: warnings.warn( - "raw_julia_state_ cannot be pickled and will be removed from the " + "sr_state_ and sr_options_ cannot be pickled and will be removed from the " "serialized instance. This will prevent a `warm_start` fit of any " "model that is deserialized via `pickle.load()`." ) @@ -1055,7 +1059,10 @@ def __getstate__(self): "serialized instance. When loading the model, please redefine " f"`{state_key}` at runtime." ) - state_keys_to_clear = ["raw_julia_state_"] + state_keys_containing_lambdas + state_keys_to_clear = [ + "sr_state_", + "sr_options_", + ] + state_keys_containing_lambdas pickled_state = { key: (None if key in state_keys_to_clear else value) for key, value in state.items() @@ -1105,6 +1112,14 @@ def equations(self): # pragma: no cover ) return self.equations_ + @property + def raw_julia_state_(self): # pragma: no cover + warnings.warn( + "PySRRegressor.raw_julia_state_ is now deprecated. " + "Please use PySRRegressor.sr_state_ instead.", + ) + return self.sr_state_ + def get_best(self, index=None): """ Get best equation using `model_selection`. @@ -1605,7 +1620,7 @@ def _run(self, X, y, mutated_params, weights, seed): # Call to Julia backend. # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl - options = SymbolicRegression.Options( + self.sr_options_ = SymbolicRegression.Options( binary_operators=Main.eval(str(binary_operators).replace("'", "")), unary_operators=Main.eval(str(unary_operators).replace("'", "")), bin_constraints=bin_constraints, @@ -1704,7 +1719,7 @@ def _run(self, X, y, mutated_params, weights, seed): # Call to Julia backend. # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/SymbolicRegression.jl - self.raw_julia_state_ = SymbolicRegression.equation_search( + self.sr_state_ = SymbolicRegression.equation_search( Main.X, Main.y, weights=Main.weights, @@ -1714,10 +1729,10 @@ def _run(self, X, y, mutated_params, weights, seed): y_variable_names=y_variable_names, X_units=self.X_units_, y_units=self.y_units_, - options=options, + options=self.sr_options_, numprocs=cprocs, parallelism=parallelism, - saved_state=self.raw_julia_state_, + saved_state=self.sr_state_, return_state=True, addprocs_function=cluster_manager, progress=progress and self.verbosity > 0 and len(y.shape) == 1, @@ -1786,10 +1801,10 @@ def fit( Fitted estimator. """ # Init attributes that are not specified in BaseEstimator - if self.warm_start and hasattr(self, "raw_julia_state_"): + if self.warm_start and hasattr(self, "sr_state_"): pass else: - if hasattr(self, "raw_julia_state_"): + if hasattr(self, "sr_state_"): warnings.warn( "The discovered expressions are being reset. " "Please set `warm_start=True` if you wish to continue " @@ -1799,7 +1814,8 @@ def fit( self.equations_ = None self.nout_ = 1 self.selection_mask_ = None - self.raw_julia_state_ = None + self.sr_state_ = None + self.sr_options_ = None self.X_units_ = None self.y_units_ = None diff --git a/pysr/test/test.py b/pysr/test/test.py index 120f34d08..df361e6fd 100644 --- a/pysr/test/test.py +++ b/pysr/test/test.py @@ -109,7 +109,7 @@ def test_high_precision_search_custom_loss(self): from pysr.sr import Main # We should have that the model state is now a Float64 hof: - Main.test_state = model.raw_julia_state_ + Main.test_state = model.sr_state_ self.assertTrue(Main.eval("typeof(test_state[2]).parameters[1] == Float64")) def test_multioutput_custom_operator_quiet_custom_complexity(self): @@ -232,7 +232,7 @@ def test_empty_operators_single_input_warm_start(self): from pysr.sr import Main # We should have that the model state is now a Float32 hof: - Main.test_state = regressor.raw_julia_state_ + Main.test_state = regressor.sr_state_ self.assertTrue(Main.eval("typeof(test_state[2]).parameters[1] == Float32")) # This should exit almost immediately, and use the old equations regressor.fit(X, y) From ff2f93a6ede8b79f921d3b20171cd19330b34e4d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 19 Sep 2023 17:53:51 +0100 Subject: [PATCH 07/17] Add missing sympy operators for boolean logic --- pysr/export_sympy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pysr/export_sympy.py b/pysr/export_sympy.py index 81142f481..c3e2f7f0e 100644 --- a/pysr/export_sympy.py +++ b/pysr/export_sympy.py @@ -47,6 +47,8 @@ "ceil": sympy.ceiling, "sign": sympy.sign, "gamma": sympy.gamma, + "logical_or": lambda x, y: sympy.Piecewise((1.0, 0.0), ((x > 0) | (y > 0), True)), + "logical_and": lambda x, y: sympy.Piecewise((1.0, 0.0), ((x > 0) & (y > 0), True)), } From d5787b273aa8ac5ee983a269612dbb0d88664dd4 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 19 Sep 2023 17:56:24 +0100 Subject: [PATCH 08/17] Add missing sympy operators for relu --- pysr/export_sympy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pysr/export_sympy.py b/pysr/export_sympy.py index c3e2f7f0e..c5f89c5e3 100644 --- a/pysr/export_sympy.py +++ b/pysr/export_sympy.py @@ -49,6 +49,7 @@ "gamma": sympy.gamma, "logical_or": lambda x, y: sympy.Piecewise((1.0, 0.0), ((x > 0) | (y > 0), True)), "logical_and": lambda x, y: sympy.Piecewise((1.0, 0.0), ((x > 0) & (y > 0), True)), + "relu": lambda x: sympy.Piecewise((0.0, x), (x < 0, True)), } From 47823bad9278483872fa88f50b8db8011b3ff70d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 22 Sep 2023 15:48:04 +0100 Subject: [PATCH 09/17] Add functionality for piecewise export to torch --- pysr/export_torch.py | 69 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/pysr/export_torch.py b/pysr/export_torch.py index 7fcb67e82..31b8626b7 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -33,6 +33,70 @@ def _initialize_torch(): torch = _torch + # Allows PyTorch to map Piecewise functions: + def expr_cond_pair(expr, cond): + if isinstance(cond, torch.Tensor) and not isinstance(expr, torch.Tensor): + expr = torch.tensor(expr, dtype=cond.dtype, device=cond.device) + elif isinstance(expr, torch.Tensor) and not isinstance(cond, torch.Tensor): + cond = torch.tensor(cond, dtype=expr.dtype, device=expr.device) + else: + return expr, cond + + # First, make sure expr and cond are same size: + if expr.shape != cond.shape: + if len(expr.shape) == 0: + expr = expr.expand(cond.shape) + elif len(cond.shape) == 0: + cond = cond.expand(expr.shape) + else: + raise ValueError( + "expr and cond must have same shape, or one must be a scalar." + ) + return expr, cond + + def piecewise(*expr_conds): + output = None + already_used = None + for expr, cond in expr_conds: + if not isinstance(cond, torch.Tensor) and not isinstance( + expr, torch.Tensor + ): + # When we just have scalars, have to do this a bit more complicated + # due to the fact that we need to evaluate on the correct device. + if output is None: + already_used = cond + output = expr if cond else 0.0 + else: + if not isinstance(output, torch.Tensor): + output += expr if cond and not already_used else 0.0 + already_used = already_used or cond + else: + expr = torch.tensor( + expr, dtype=output.dtype, device=output.device + ).expand(output.shape) + output += torch.where( + cond & ~already_used, expr, torch.zeros_like(expr) + ) + already_used = already_used | cond + else: + if output is None: + already_used = cond + output = torch.where(cond, expr, torch.zeros_like(expr)) + else: + output += torch.where( + cond & ~already_used, expr, torch.zeros_like(expr) + ) + already_used = already_used | cond + return output + + def as_bool(x): + if isinstance(x, torch.Tensor): + return x.bool() + else: + return bool(x) + + # TODO: Add test that makes sure tensors are on the same device + _global_func_lookup = { sympy.Mul: _reduce(torch.mul), sympy.Add: _reduce(torch.add), @@ -81,6 +145,11 @@ def _initialize_torch(): sympy.Heaviside: torch.heaviside, sympy.core.numbers.Half: (lambda: 0.5), sympy.core.numbers.One: (lambda: 1.0), + sympy.logic.boolalg.Boolean: as_bool, + sympy.logic.boolalg.BooleanTrue: (lambda: True), + sympy.logic.boolalg.BooleanFalse: (lambda: False), + sympy.functions.elementary.piecewise.ExprCondPair: expr_cond_pair, + sympy.Piecewise: piecewise, } class _Node(torch.nn.Module): From 73d0f8a8172897c7c7b238a5cbd9621d7a76753e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 22 Sep 2023 15:48:53 +0100 Subject: [PATCH 10/17] Clean up error message in exports --- pysr/export_jax.py | 2 +- pysr/export_torch.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/pysr/export_jax.py b/pysr/export_jax.py index e1730ca48..1a03b4545 100644 --- a/pysr/export_jax.py +++ b/pysr/export_jax.py @@ -69,7 +69,7 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None): _func = {**_jnp_func_lookup, **extra_jax_mappings}[expr.func] except KeyError: raise KeyError( - f"Function {expr.func} was not found in JAX function mappings." + f"Function {expr.func} was not found in JAX function mappings. " "Please add it to extra_jax_mappings in the format, e.g., " "{sympy.sqrt: 'jnp.sqrt'}." ) diff --git a/pysr/export_torch.py b/pysr/export_torch.py index 31b8626b7..2efed29ab 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -194,7 +194,7 @@ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs): self._torch_func = _func_lookup[expr.func] except KeyError: raise KeyError( - f"Function {expr.func} was not found in Torch function mappings." + f"Function {expr.func} was not found in Torch function mappings. " "Please add it to extra_torch_mappings in the format, e.g., " "{sympy.sqrt: torch.sqrt}." ) @@ -222,7 +222,13 @@ def forward(self, memodict): arg_ = arg(memodict) memodict[arg] = arg_ args.append(arg_) - return self._torch_func(*args) + try: + return self._torch_func(*args) + except Exception as err: + # Add information about the current node to the error: + raise type(err)( + f"Error occurred in node {self._sympy_func} with args {args}" + ) class _SingleSymPyModule(torch.nn.Module): """SympyTorch code from https://github.com/patrick-kidger/sympytorch""" From f92a93508fc0ef91feb0e68e4a4fe0236cb35013 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 22 Sep 2023 15:49:08 +0100 Subject: [PATCH 11/17] Implement relu, logical_or, logical_and --- pysr/export_sympy.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pysr/export_sympy.py b/pysr/export_sympy.py index c5f89c5e3..a2f61e900 100644 --- a/pysr/export_sympy.py +++ b/pysr/export_sympy.py @@ -47,9 +47,9 @@ "ceil": sympy.ceiling, "sign": sympy.sign, "gamma": sympy.gamma, - "logical_or": lambda x, y: sympy.Piecewise((1.0, 0.0), ((x > 0) | (y > 0), True)), - "logical_and": lambda x, y: sympy.Piecewise((1.0, 0.0), ((x > 0) & (y > 0), True)), - "relu": lambda x: sympy.Piecewise((0.0, x), (x < 0, True)), + "logical_or": lambda x, y: sympy.Piecewise((1.0, (x > 0) | (y > 0)), (0.0, True)), + "logical_and": lambda x, y: sympy.Piecewise((1.0, (x > 0) & (y > 0)), (0.0, True)), + "relu": lambda x: sympy.Piecewise((0.0, x < 0), (x, True)), } From 2a20447f572a29a36a926484f1e06652741d3e52 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 22 Sep 2023 15:55:47 +0100 Subject: [PATCH 12/17] Remove unnecessary as_bool --- pysr/export_torch.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pysr/export_torch.py b/pysr/export_torch.py index 2efed29ab..6b5849688 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -89,12 +89,6 @@ def piecewise(*expr_conds): already_used = already_used | cond return output - def as_bool(x): - if isinstance(x, torch.Tensor): - return x.bool() - else: - return bool(x) - # TODO: Add test that makes sure tensors are on the same device _global_func_lookup = { @@ -145,7 +139,7 @@ def as_bool(x): sympy.Heaviside: torch.heaviside, sympy.core.numbers.Half: (lambda: 0.5), sympy.core.numbers.One: (lambda: 1.0), - sympy.logic.boolalg.Boolean: as_bool, + sympy.logic.boolalg.Boolean: lambda x: x, sympy.logic.boolalg.BooleanTrue: (lambda: True), sympy.logic.boolalg.BooleanFalse: (lambda: False), sympy.functions.elementary.piecewise.ExprCondPair: expr_cond_pair, From 11dea32b0c7eaf7ffad493d1622fb8bc7c71c9a3 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 14 Dec 2023 11:56:41 -0600 Subject: [PATCH 13/17] Replace Heaviside with piecewise --- pysr/export_sympy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pysr/export_sympy.py b/pysr/export_sympy.py index 359ab232b..b14db0ee0 100644 --- a/pysr/export_sympy.py +++ b/pysr/export_sympy.py @@ -50,7 +50,7 @@ "max": lambda x, y: sympy.Piecewise((y, x < y), (x, True)), "min": lambda x, y: sympy.Piecewise((x, x < y), (y, True)), "round": lambda x: sympy.ceiling(x - 0.5), - "cond": lambda x, y: sympy.Heaviside(x, H0=0) * y, + "cond": lambda x, y: sympy.Piecewise((y, x > 0), (0.0, True)), "logical_or": lambda x, y: sympy.Piecewise((1.0, (x > 0) | (y > 0)), (0.0, True)), "logical_and": lambda x, y: sympy.Piecewise((1.0, (x > 0) & (y > 0)), (0.0, True)), "relu": lambda x: sympy.Piecewise((0.0, x < 0), (x, True)), From cff611ad85c96de739f6338847a3e9ce571c1811 Mon Sep 17 00:00:00 2001 From: Miles Cranmer Date: Mon, 3 Jun 2024 23:08:47 +0900 Subject: [PATCH 14/17] Apply suggestions from code review Co-authored-by: tbuckworth <55180288+tbuckworth@users.noreply.github.com> --- pysr/export_torch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pysr/export_torch.py b/pysr/export_torch.py index 6b5849688..aebb0e148 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -84,9 +84,9 @@ def piecewise(*expr_conds): output = torch.where(cond, expr, torch.zeros_like(expr)) else: output += torch.where( - cond & ~already_used, expr, torch.zeros_like(expr) + cond.bool() & ~already_used, expr, torch.zeros_like(expr) ) - already_used = already_used | cond + already_used = already_used | cond.bool() return output # TODO: Add test that makes sure tensors are on the same device @@ -144,6 +144,7 @@ def piecewise(*expr_conds): sympy.logic.boolalg.BooleanFalse: (lambda: False), sympy.functions.elementary.piecewise.ExprCondPair: expr_cond_pair, sympy.Piecewise: piecewise, + sympy.logic.boolalg.ITE: if_then_else, } class _Node(torch.nn.Module): From 3f1524b74b772847598f56c34e5e201a8a0764a8 Mon Sep 17 00:00:00 2001 From: Miles Cranmer Date: Mon, 3 Jun 2024 23:09:00 +0900 Subject: [PATCH 15/17] Update pysr/export_torch.py Co-authored-by: tbuckworth <55180288+tbuckworth@users.noreply.github.com> --- pysr/export_torch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pysr/export_torch.py b/pysr/export_torch.py index aebb0e148..3a3c2d539 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -54,6 +54,10 @@ def expr_cond_pair(expr, cond): ) return expr, cond +def if_then_else(*conds): + a, b, c = conds + return torch.where(a, torch.where(b, True, False), torch.where(c, True, False)) + def piecewise(*expr_conds): output = None already_used = None From 01e1a15d288e5155009033a8686e88a6a7c157ef Mon Sep 17 00:00:00 2001 From: Miles Cranmer Date: Tue, 4 Jun 2024 00:39:45 +0900 Subject: [PATCH 16/17] Update pysr/export_torch.py --- pysr/export_torch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pysr/export_torch.py b/pysr/export_torch.py index 3a3c2d539..8cf61e370 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -54,9 +54,9 @@ def expr_cond_pair(expr, cond): ) return expr, cond -def if_then_else(*conds): - a, b, c = conds - return torch.where(a, torch.where(b, True, False), torch.where(c, True, False)) + def if_then_else(*conds): + a, b, c = conds + return torch.where(a, torch.where(b, True, False), torch.where(c, True, False)) def piecewise(*expr_conds): output = None From 5c0a49a0827b287f7749a119ad24b16377452237 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 3 Jun 2024 15:40:13 +0000 Subject: [PATCH 17/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pysr/export_torch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pysr/export_torch.py b/pysr/export_torch.py index 8cf61e370..0b9113d73 100644 --- a/pysr/export_torch.py +++ b/pysr/export_torch.py @@ -56,7 +56,9 @@ def expr_cond_pair(expr, cond): def if_then_else(*conds): a, b, c = conds - return torch.where(a, torch.where(b, True, False), torch.where(c, True, False)) + return torch.where( + a, torch.where(b, True, False), torch.where(c, True, False) + ) def piecewise(*expr_conds): output = None