Skip to content

Commit

Permalink
test: improve coverage for new 1.0 code
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Nov 28, 2024
1 parent df2db4a commit 809bb74
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 7 deletions.
12 changes: 5 additions & 7 deletions pysr/expression_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

# For type checking purposes
if TYPE_CHECKING:
from .sr import PySRRegressor
from .sr import PySRRegressor # pragma: no cover

PySRRegressor: TypeAlias = PySRRegressor
PySRRegressor: TypeAlias = PySRRegressor # pragma: no cover
else:
PySRRegressor = NewType("PySRRegressor", Any)

Expand Down Expand Up @@ -47,12 +47,12 @@ class AbstractExpressionSpec(ABC):
@abstractmethod
def julia_expression_type(self) -> AnyValue:
"""The expression type"""
pass
pass # pragma: no cover

@abstractmethod
def julia_expression_options(self) -> AnyValue:
"""The expression options"""
pass
pass # pragma: no cover

@abstractmethod
def create_exports(
Expand All @@ -62,7 +62,7 @@ def create_exports(
search_output: Any,
) -> pd.DataFrame:
"""Create additional columns in the equations dataframe."""
pass
pass # pragma: no cover

@property
def evaluates_in_julia(self) -> bool:
Expand Down Expand Up @@ -247,8 +247,6 @@ def __init__(self, expression):
self.expression = expression

def __call__(self, X: np.ndarray, *args):
if not isinstance(X, np.ndarray):
raise ValueError("X must be a numpy array")
raw_output = self.expression(jl_array(X.T), *args)
return np.array(raw_output).T

Expand Down
6 changes: 6 additions & 0 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2364,6 +2364,10 @@ def sympy(self, index: Optional[Union[int, List[int]]] = None):
best_equation : str, list[str] of length nout_
SymPy representation of the best equation.
"""
if not self.expression_spec_.supports_sympy:
raise ValueError(
f"`expression_spec={self.expression_spec_}` does not support sympy export."
)
self.refresh()
best_equation = self.get_best(index=index)
if isinstance(best_equation, list):
Expand Down Expand Up @@ -2558,6 +2562,8 @@ def get_hof(
if should_read_from_file:
self.equation_file_contents_ = self._read_equation_file()

_validate_export_mappings(self.extra_jax_mappings, self.extra_torch_mappings)

equation_file_contents = cast(List[pd.DataFrame], self.equation_file_contents_)

ret_outputs = [
Expand Down
36 changes: 36 additions & 0 deletions pysr/test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,23 @@ def test_template_expressions(self):
test_mse = np.mean((y_test - y_pred) ** 2)
self.assertLess(test_mse, 1e-5)

# Make sure that a nice error is raised if we try to get the sympy expression:
# f"`expression_spec={self.expression_spec_}` does not support sympy export."
with self.assertRaises(ValueError) as cm:
model.sympy()
self.assertRegex(
str(cm.exception),
r"`expression_spec=.*TemplateExpressionSpec.*` does not support sympy export.",
)
with self.assertRaises(ValueError):
model.latex()
with self.assertRaises(ValueError):
model.jax()
with self.assertRaises(ValueError):
model.pytorch()
with self.assertRaises(ValueError):
model.latex_table()

def test_parametric_expression(self):
# Create data with two classes
n_points = 100
Expand Down Expand Up @@ -593,6 +610,17 @@ def test_parametric_expression(self):
test_mse = np.mean((y_test - y_test_pred) ** 2)
self.assertLess(test_mse, 1e-3)

with self.assertRaises(ValueError):
model.sympy()
with self.assertRaises(ValueError):
model.latex()
with self.assertRaises(ValueError):
model.jax()
with self.assertRaises(ValueError):
model.pytorch()
with self.assertRaises(ValueError):
model.latex_table()


def manually_create_model(equations, feature_names=None):
if feature_names is None:
Expand Down Expand Up @@ -841,6 +869,14 @@ def test_deprecation(self):
# The correct value should be set:
self.assertEqual(model.fraction_replaced, 0.2)

with self.assertRaises(NotImplementedError):
model.equation_file_

with self.assertRaises(ValueError) as cm:
PySRRegressor.from_file(equation_file="", run_directory="")

self.assertIn("Passing `equation_file` is deprecated", str(cm.exception))

def test_deprecated_functions(self):
with self.assertWarns(FutureWarning):
install()
Expand Down

0 comments on commit 809bb74

Please sign in to comment.