Skip to content

Commit

Permalink
style: 💄 apply black to py and NB
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Bury committed Jul 29, 2023
1 parent ba01f43 commit 9de6795
Show file tree
Hide file tree
Showing 19 changed files with 862 additions and 361 deletions.
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@
# a list of builtin themes.
#
# html_theme = "sphinx_rtd_theme"
html_permalinks_icon = '<span>#</span>'
html_theme = 'sphinxawesome_theme'
html_permalinks_icon = "<span>#</span>"
html_theme = "sphinxawesome_theme"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
Expand Down
20 changes: 15 additions & 5 deletions docs/notebooks/arfs_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,12 @@
"model = clone(model)\n",
"\n",
"feat_selector = arfsgroot.Leshy(\n",
" model, n_estimators=50, verbose=1, max_iter=10, random_state=42, importance=\"fastshap\"\n",
" model,\n",
" n_estimators=50,\n",
" verbose=1,\n",
" max_iter=10,\n",
" random_state=42,\n",
" importance=\"fastshap\",\n",
")\n",
"feat_selector.fit(X, y, sample_weight=None)\n",
"print(f\"The selected features: {feat_selector.get_feature_names_out()}\")\n",
Expand Down Expand Up @@ -1946,7 +1951,7 @@
" CatBoostClassifier(random_state=42, verbose=0),\n",
" LGBMClassifier(random_state=42, verbose=-1),\n",
" LightForestClassifier(n_feat=X.shape[1]),\n",
" XGBClassifier(random_state=42, verbosity=0, eval_metric='logloss'),\n",
" XGBClassifier(random_state=42, verbosity=0, eval_metric=\"logloss\"),\n",
"]\n",
"\n",
"feat_selector = arfsgroot.Leshy(\n",
Expand Down Expand Up @@ -1987,10 +1992,14 @@
"from xgboost import XGBClassifier\n",
"from fasttreeshap import TreeExplainer as FastTreeExplainer\n",
"\n",
"X, y = make_classification(n_samples=1000, n_features=10, n_informative=8, random_state=8)\n",
"model = XGBClassifier() \n",
"X, y = make_classification(\n",
" n_samples=1000, n_features=10, n_informative=8, random_state=8\n",
")\n",
"model = XGBClassifier()\n",
"model.fit(X, y)\n",
"explainer = FastTreeExplainer(model, algorithm=\"auto\", shortcut=False, feature_perturbation=\"tree_path_dependent\")\n",
"explainer = FastTreeExplainer(\n",
" model, algorithm=\"auto\", shortcut=False, feature_perturbation=\"tree_path_dependent\"\n",
")\n",
"shap_matrix = explainer.shap_values(X)"
]
},
Expand Down Expand Up @@ -3036,6 +3045,7 @@
"source": [
"# Leshy\n",
"from arfs.preprocessing import OrdinalEncoderPandas\n",
"\n",
"model = LGBMClassifier(random_state=42, verbose=-1, n_estimators=10)\n",
"X_encoded = OrdinalEncoderPandas().fit_transform(X=X)\n",
"feat_selector = arfsgroot.Leshy(\n",
Expand Down
46 changes: 38 additions & 8 deletions docs/notebooks/arfs_grootcv_custom_params.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,21 @@
"import os\n",
"import multiprocessing\n",
"\n",
"\n",
"def get_physical_cores():\n",
" if os.name == 'posix': # For Unix-based systems (e.g., Linux, macOS)\n",
" if os.name == \"posix\": # For Unix-based systems (e.g., Linux, macOS)\n",
" try:\n",
" return os.sysconf(\"SC_NPROCESSORS_ONLN\")\n",
" except ValueError:\n",
" pass\n",
" elif os.name == 'nt': # For Windows\n",
" elif os.name == \"nt\": # For Windows\n",
" try:\n",
" return int(os.environ[\"NUMBER_OF_PROCESSORS\"])\n",
" except (ValueError, KeyError):\n",
" pass\n",
" return multiprocessing.cpu_count()\n",
"\n",
"\n",
"num_physical_cores = get_physical_cores()\n",
"print(f\"Number of physical cores: {num_physical_cores}\")"
]
Expand Down Expand Up @@ -508,12 +510,19 @@
"source": [
"for n_jobs in range(num_physical_cores):\n",
" start_time = time.time()\n",
" feat_selector = GrootCV(objective=\"rmse\", cutoff=1, n_folds=5, n_iter=5, silent=True, fastshap=False, n_jobs=n_jobs)\n",
" feat_selector = GrootCV(\n",
" objective=\"rmse\",\n",
" cutoff=1,\n",
" n_folds=5,\n",
" n_iter=5,\n",
" silent=True,\n",
" fastshap=False,\n",
" n_jobs=n_jobs,\n",
" )\n",
" feat_selector.fit(X, y, sample_weight=None)\n",
" end_time = time.time()\n",
" execution_time = end_time - start_time\n",
" print(f\"n_jobs = {n_jobs}, Execution time: {execution_time:.3f} seconds\")\n",
"\n"
" print(f\"n_jobs = {n_jobs}, Execution time: {execution_time:.3f} seconds\")"
]
},
{
Expand Down Expand Up @@ -572,7 +581,14 @@
"source": [
"# GrootCV with less regularization\n",
"feat_selector = GrootCV(\n",
" objective=\"rmse\", cutoff=1, n_folds=5, n_iter=5, silent=True, fastshap=True, n_jobs=0, lgbm_params={\"min_data_in_leaf\": 10}\n",
" objective=\"rmse\",\n",
" cutoff=1,\n",
" n_folds=5,\n",
" n_iter=5,\n",
" silent=True,\n",
" fastshap=True,\n",
" n_jobs=0,\n",
" lgbm_params={\"min_data_in_leaf\": 10},\n",
")\n",
"feat_selector.fit(X, y, sample_weight=None)\n",
"print(f\"The selected features: {feat_selector.get_feature_names_out()}\")\n",
Expand Down Expand Up @@ -628,7 +644,14 @@
"source": [
"# GrootCV with default regularization\n",
"feat_selector = GrootCV(\n",
" objective=\"rmse\", cutoff=1, n_folds=5, n_iter=5, silent=True, fastshap=True, n_jobs=0, lgbm_params=None\n",
" objective=\"rmse\",\n",
" cutoff=1,\n",
" n_folds=5,\n",
" n_iter=5,\n",
" silent=True,\n",
" fastshap=True,\n",
" n_jobs=0,\n",
" lgbm_params=None,\n",
")\n",
"feat_selector.fit(X, y, sample_weight=None)\n",
"print(f\"The selected features: {feat_selector.get_feature_names_out()}\")\n",
Expand Down Expand Up @@ -684,7 +707,14 @@
"source": [
"# GrootCV with larger regularization\n",
"feat_selector = GrootCV(\n",
" objective=\"rmse\", cutoff=1, n_folds=5, n_iter=5, silent=True, fastshap=True, n_jobs=0, lgbm_params={\"min_data_in_leaf\": 100}\n",
" objective=\"rmse\",\n",
" cutoff=1,\n",
" n_folds=5,\n",
" n_iter=5,\n",
" silent=True,\n",
" fastshap=True,\n",
" n_jobs=0,\n",
" lgbm_params={\"min_data_in_leaf\": 100},\n",
")\n",
"feat_selector.fit(X, y, sample_weight=None)\n",
"print(f\"The selected features: {feat_selector.get_feature_names_out()}\")\n",
Expand Down
Loading

0 comments on commit 9de6795

Please sign in to comment.