From 3dad20837c885cbf31a8f63f4b93bf2b7904a4a8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 28 Nov 2024 13:00:46 +0000 Subject: [PATCH] fix: only convert dtype for Julia evaluations --- pysr/expression_specs.py | 12 ++++++++++++ pysr/sr.py | 5 ++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pysr/expression_specs.py b/pysr/expression_specs.py index c7ac5327..a4c77189 100644 --- a/pysr/expression_specs.py +++ b/pysr/expression_specs.py @@ -63,6 +63,10 @@ def create_exports( """Create additional columns in the equations dataframe.""" pass + @property + def evaluates_in_julia(self) -> bool: + return False + @property def supports_sympy(self) -> bool: return False @@ -197,6 +201,10 @@ def julia_expression_options(self): ) return creator(self.function_symbols, f_combine, self.num_features) + @property + def evaluates_in_julia(self): + return True + def create_exports( self, model: PySRRegressor, @@ -219,6 +227,10 @@ def julia_expression_type(self): def julia_expression_options(self): return jl.seval("NamedTuple{(:max_parameters,)}")((self.max_parameters,)) + @property + def evaluates_in_julia(self): + return True + def create_exports( self, model: PySRRegressor, diff --git a/pysr/sr.py b/pysr/sr.py index 7fb7da01..e644d8f5 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -2313,7 +2313,10 @@ def predict( # feature selected) X in fit. X = X.reindex(columns=self.feature_names_in_) X = self._validate_data_X(X) - X = X.astype(self._get_precision_mapped_dtype(X)) + if self.expression_spec_.evaluates_in_julia: + # Julia wants the right dtype + X = X.astype(self._get_precision_mapped_dtype(X)) + if category is not None: offset_for_julia_indexing = 1 args: tuple = (