Skip to content

Commit

Permalink
fix: only convert dtype for Julia evaluations
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Nov 28, 2024
1 parent 9ce96a7 commit 3dad208
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
12 changes: 12 additions & 0 deletions pysr/expression_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def create_exports(
"""Create additional columns in the equations dataframe."""
pass

@property
def evaluates_in_julia(self) -> bool:
return False

@property
def supports_sympy(self) -> bool:
return False
Expand Down Expand Up @@ -197,6 +201,10 @@ def julia_expression_options(self):
)
return creator(self.function_symbols, f_combine, self.num_features)

@property
def evaluates_in_julia(self):
return True

def create_exports(
self,
model: PySRRegressor,
Expand All @@ -219,6 +227,10 @@ def julia_expression_type(self):
def julia_expression_options(self):
return jl.seval("NamedTuple{(:max_parameters,)}")((self.max_parameters,))

@property
def evaluates_in_julia(self):
return True

def create_exports(
self,
model: PySRRegressor,
Expand Down
5 changes: 4 additions & 1 deletion pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2313,7 +2313,10 @@ def predict(
# feature selected) X in fit.
X = X.reindex(columns=self.feature_names_in_)
X = self._validate_data_X(X)
X = X.astype(self._get_precision_mapped_dtype(X))
if self.expression_spec_.evaluates_in_julia:
# Julia wants the right dtype
X = X.astype(self._get_precision_mapped_dtype(X))

if category is not None:
offset_for_julia_indexing = 1
args: tuple = (
Expand Down

0 comments on commit 3dad208

Please sign in to comment.