From 4b9913448ae9a1230d321612b02d35b3ac82bb98 Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Thu, 31 Oct 2024 13:57:00 -0400 Subject: [PATCH] new optuna progress bar --- bnpm/optimization.py | 103 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 2 deletions(-) diff --git a/bnpm/optimization.py b/bnpm/optimization.py index 4c627f4..5a00efd 100644 --- a/bnpm/optimization.py +++ b/bnpm/optimization.py @@ -1,10 +1,12 @@ from typing import Dict, Type, Any, Union, Optional, Callable, Tuple, List import time -import warnings -import itertools +import logging +import os import numpy as np import torch +import optuna +from tqdm.auto import tqdm class Convergence_checker: """ @@ -320,3 +322,100 @@ def check( if self.verbose: print(f'Trial num: {self.num_trial}. Duration: {duration:.3f}s. Best value: {self.best:3e}. Current value:{trial.value:3e}') if self.verbose else None self.num_trial += 1 + + +class OptunaProgressBar: + """ + A customizable progress bar for Optuna's study.optimize(). + + Args: + n_trials (int, optional): + The number of trials. Required if timeout is not set. + timeout (float, optional): + The maximum time to run in seconds. Required if n_trials is not set. + tqdm_kwargs (dict, optional): + Additional keyword arguments to pass to tqdm. + """ + + def __init__(self, n_trials=None, timeout=None, **tqdm_kwargs): + from optuna import logging as optuna_logging + + self._n_trials = n_trials + self._timeout = timeout + self._tqdm_kwargs = tqdm_kwargs + + # Read TQDM environment variables + self._load_env_variables() + + # Initialize progress bar + self._progress_bar = None + self._last_elapsed_seconds = 0.0 + + # Setup logging handler to redirect Optuna logs to tqdm + self._tqdm_handler = _TqdmLoggingHandler() + self._tqdm_handler.setLevel(logging.INFO) + self._tqdm_handler.setFormatter(optuna_logging.create_default_formatter()) + optuna_logging.disable_default_handler() + optuna_logging._get_library_root_logger().addHandler(self._tqdm_handler) + + def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial): + # Initialize progress bar on first call + if self._progress_bar is None: + if self._n_trials is not None: + self._progress_bar = tqdm(total=self._n_trials, **self._tqdm_kwargs) + elif self._timeout is not None: + total = tqdm.format_interval(self._timeout) + fmt = "{desc} {percentage:3.0f}%|{bar}| {elapsed}/" + total + self._progress_bar = tqdm(total=self._timeout, bar_format=fmt, **self._tqdm_kwargs) + else: + raise ValueError("Either n_trials or timeout must be set.") + + # Update progress bar + if self._n_trials is not None: + self._progress_bar.update(1) + elif self._timeout is not None: + elapsed = study._stop_watch.elapsed_time() + time_diff = elapsed - self._last_elapsed_seconds + self._progress_bar.update(time_diff) + self._last_elapsed_seconds = elapsed + + # Update description with best trial information + if not study._is_multi_objective(): + try: + best_value = study.best_value + self._progress_bar.set_description(f"Best value: {best_value:.6g}") + except ValueError: + pass # No trials completed yet + + def close(self): + from optuna import logging as optuna_logging + if self._progress_bar is not None: + self._progress_bar.close() + # Restore Optuna's default logging handler + optuna_logging._get_library_root_logger().removeHandler(self._tqdm_handler) + optuna_logging.enable_default_handler() + + def _load_env_variables(self): + # Load TQDM environment variables + env_vars = { + 'disable': os.environ.get('TQDM_DISABLE', '').lower() == 'true', + 'mininterval': float(os.environ.get('TQDM_MININTERVAL', 0.1)), + 'maxinterval': float(os.environ.get('TQDM_MAXINTERVAL', 10.0)), + 'miniters': int(os.environ.get('TQDM_MINITERS', 1)), + 'smoothing': float(os.environ.get('TQDM_SMOOTHING', 0.3)), + } + # Update tqdm_kwargs with environment variables if not already set + for key, value in env_vars.items(): + if key not in self._tqdm_kwargs: + self._tqdm_kwargs[key] = value +class _TqdmLoggingHandler(logging.StreamHandler): + """Logging handler to redirect Optuna logs to tqdm.write().""" + def emit(self, record: logging.LogRecord) -> None: + try: + msg = self.format(record) + tqdm.write(msg) + self.flush() + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + self.handleError(record) \ No newline at end of file