Skip to content

Commit

Permalink
feat: use dtype dic for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Bury committed Jun 16, 2023
1 parent 5f356dc commit 526f04b
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/arfs/gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import gc
import os
from pathlib import Path
from arfs.utils import create_dtype_dict

QUAL_COLORS = [
(0.188235, 0.635294, 0.854902),
Expand Down Expand Up @@ -180,9 +181,9 @@ def fit(
raise KeyError("Provide the objective in the params dict")

if self.cat_feat == "auto":
self.cat_feat = list(
set(list(X.columns)) - set(list(X.select_dtypes(include=[np.number])))
)
dtypes_dic = create_dtype_dict(df=X, dic_keys="dtypes")
category_cols = dtypes_dic["cat"] + dtypes_dic["time"] + dtypes_dic["unk"]
self.cat_feat = category_cols if category_cols else None

if not isinstance(X, (pd.Series, pd.DataFrame)):
X = pd.DataFrame(X)
Expand Down

0 comments on commit 526f04b

Please sign in to comment.