diff --git a/bnpm/automatic_regression.py b/bnpm/automatic_regression.py index 2d2903c..d1bb485 100644 --- a/bnpm/automatic_regression.py +++ b/bnpm/automatic_regression.py @@ -169,7 +169,7 @@ def __init__( 'kwargs_convergence': self.kwargs_convergence, 'n_jobs_optuna': self.n_jobs_optuna, } - self.callback_wandb = optuna.integration.WeightsAndBiasesCallback( + callback_wandb = optuna.integration.WeightsAndBiasesCallback( metric_name="loss", wandb_kwargs={ 'project': self.wandb_project, @@ -177,6 +177,13 @@ def __init__( }, as_multirun=False, ) + ## Make a safe version of the callback by putting it in a try-except block + def safe_callback(study, trial): + try: + callback_wandb(study, trial) + except Exception as e: + print(f'Error in wandb callback: {e}') + self.callback_wandb = safe_callback else: self.callback_wandb = None @@ -295,7 +302,8 @@ def _objective(self, trial: optuna.trial.Trial) -> float: self.model_best = model self.params_best = kwargs_model - return loss + # return loss + return np.nan def fit(self) -> Union[sklearn.base.BaseEstimator, Optional[Dict[str, Any]]]: """