Skip to content

Commit

Permalink
new optuna progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Oct 31, 2024
1 parent 2559fdf commit 4b99134
Showing 1 changed file with 101 additions and 2 deletions.
103 changes: 101 additions & 2 deletions bnpm/optimization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Dict, Type, Any, Union, Optional, Callable, Tuple, List
import time
import warnings
import itertools
import logging
import os

import numpy as np
import torch
import optuna
from tqdm.auto import tqdm

class Convergence_checker:
"""
Expand Down Expand Up @@ -320,3 +322,100 @@ def check(
if self.verbose:
print(f'Trial num: {self.num_trial}. Duration: {duration:.3f}s. Best value: {self.best:3e}. Current value:{trial.value:3e}') if self.verbose else None
self.num_trial += 1


class OptunaProgressBar:
"""
A customizable progress bar for Optuna's study.optimize().
Args:
n_trials (int, optional):
The number of trials. Required if timeout is not set.
timeout (float, optional):
The maximum time to run in seconds. Required if n_trials is not set.
tqdm_kwargs (dict, optional):
Additional keyword arguments to pass to tqdm.
"""

def __init__(self, n_trials=None, timeout=None, **tqdm_kwargs):
from optuna import logging as optuna_logging

self._n_trials = n_trials
self._timeout = timeout
self._tqdm_kwargs = tqdm_kwargs

# Read TQDM environment variables
self._load_env_variables()

# Initialize progress bar
self._progress_bar = None
self._last_elapsed_seconds = 0.0

# Setup logging handler to redirect Optuna logs to tqdm
self._tqdm_handler = _TqdmLoggingHandler()
self._tqdm_handler.setLevel(logging.INFO)
self._tqdm_handler.setFormatter(optuna_logging.create_default_formatter())
optuna_logging.disable_default_handler()
optuna_logging._get_library_root_logger().addHandler(self._tqdm_handler)

def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial):
# Initialize progress bar on first call
if self._progress_bar is None:
if self._n_trials is not None:
self._progress_bar = tqdm(total=self._n_trials, **self._tqdm_kwargs)
elif self._timeout is not None:
total = tqdm.format_interval(self._timeout)
fmt = "{desc} {percentage:3.0f}%|{bar}| {elapsed}/" + total
self._progress_bar = tqdm(total=self._timeout, bar_format=fmt, **self._tqdm_kwargs)
else:
raise ValueError("Either n_trials or timeout must be set.")

# Update progress bar
if self._n_trials is not None:
self._progress_bar.update(1)
elif self._timeout is not None:
elapsed = study._stop_watch.elapsed_time()
time_diff = elapsed - self._last_elapsed_seconds
self._progress_bar.update(time_diff)
self._last_elapsed_seconds = elapsed

# Update description with best trial information
if not study._is_multi_objective():
try:
best_value = study.best_value
self._progress_bar.set_description(f"Best value: {best_value:.6g}")
except ValueError:
pass # No trials completed yet

def close(self):
from optuna import logging as optuna_logging
if self._progress_bar is not None:
self._progress_bar.close()
# Restore Optuna's default logging handler
optuna_logging._get_library_root_logger().removeHandler(self._tqdm_handler)
optuna_logging.enable_default_handler()

def _load_env_variables(self):
# Load TQDM environment variables
env_vars = {
'disable': os.environ.get('TQDM_DISABLE', '').lower() == 'true',
'mininterval': float(os.environ.get('TQDM_MININTERVAL', 0.1)),
'maxinterval': float(os.environ.get('TQDM_MAXINTERVAL', 10.0)),
'miniters': int(os.environ.get('TQDM_MINITERS', 1)),
'smoothing': float(os.environ.get('TQDM_SMOOTHING', 0.3)),
}
# Update tqdm_kwargs with environment variables if not already set
for key, value in env_vars.items():
if key not in self._tqdm_kwargs:
self._tqdm_kwargs[key] = value
class _TqdmLoggingHandler(logging.StreamHandler):
"""Logging handler to redirect Optuna logs to tqdm.write()."""
def emit(self, record: logging.LogRecord) -> None:
try:
msg = self.format(record)
tqdm.write(msg)
self.flush()
except (KeyboardInterrupt, SystemExit):
raise
except Exception:
self.handleError(record)

0 comments on commit 4b99134

Please sign in to comment.