From a592703abce3d7c82cbae2fffffc91201801cdac Mon Sep 17 00:00:00 2001 From: Jannis Becktepe Date: Tue, 15 Oct 2024 17:11:37 +0200 Subject: [PATCH] fix: PBT self-destruct for directories + files --- hydra_plugins/hyper_pbt/hyper_pbt.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/hydra_plugins/hyper_pbt/hyper_pbt.py b/hydra_plugins/hyper_pbt/hyper_pbt.py index a034bf6..0fe3243 100644 --- a/hydra_plugins/hyper_pbt/hyper_pbt.py +++ b/hydra_plugins/hyper_pbt/hyper_pbt.py @@ -2,13 +2,15 @@ from __future__ import annotations +import os +import shutil + import numpy as np -from ConfigSpace.hyperparameters import ( - CategoricalHyperparameter, - NormalIntegerHyperparameter, - OrdinalHyperparameter, - UniformIntegerHyperparameter, -) +from ConfigSpace.hyperparameters import (CategoricalHyperparameter, + NormalIntegerHyperparameter, + OrdinalHyperparameter, + UniformIntegerHyperparameter) + from hydra_plugins.hypersweeper import Info @@ -171,12 +173,15 @@ def tell(self, info, value): def remove_checkpoints(self, iteration: int) -> None: """Remove checkpoints.""" - import os - # Delete all files in checkpoints dir starting with iteration_{iteration} for file in os.listdir(self.checkpoint_dir): if file.startswith(f"iteration_{iteration}"): - os.remove(os.path.join(self.checkpoint_dir, file)) + file_path = os.path.join(self.checkpoint_dir, file) + if os.path.isfile(file_path): + os.remove(file_path) + else: + shutil.rmtree(file_path) + def make_pbt(configspace, pbt_args): """Make a PBT instance for optimization."""