Skip to content

Commit

Permalink
add new checkpointer logic and hooks for save_state
Browse files Browse the repository at this point in the history
  • Loading branch information
fpaissan committed Nov 23, 2023
1 parent 5c8de1f commit 995462e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 38 deletions.
82 changes: 46 additions & 36 deletions micromind/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@

from accelerate import Accelerator
import torch
import os

from .utils.helpers import get_random_string
from .utils.checkpointer import Checkpointer

# This is used ONLY if you are not using argparse to get the hparams
default_cfg = {
Expand Down Expand Up @@ -156,8 +152,8 @@ def __init__(self, hparams=None):
self.hparams = hparams
self.input_shape = None

self.device = "cpu" # used just to init the models
self.accelerator = Accelerator()
self.device = self.accelerator.device

self.current_epoch = 0

Expand Down Expand Up @@ -306,44 +302,55 @@ def on_train_start(self):
This function gets executed at the beginning of every training.
"""
self.experiment_folder = os.path.join(
self.hparams.output_folder, self.hparams.experiment_name
)
if self.hparams.debug:
self.experiment_folder = "tmp_" + get_random_string()
logger.info(f"Created temporary folder for debug {self.experiment_folder}.")

accelerate_dir = os.path.join(self.experiment_folder, "save")
if os.path.exists(accelerate_dir):
self.opt, self.lr_sched = self.configure_optimizers()

init_opt = self.configure_optimizers()
if isinstance(init_opt, list) or isinstance(init_opt, tuple):
self.opt, self.lr_sched = init_opt
else:
os.makedirs(self.experiment_folder, exist_ok=True)
self.opt = init_opt

self.opt, self.lr_sched = self.configure_optimizers()
self.start_epoch = 0
self.init_devices()

if self.checkpointer is not None:
# recover state
# self.checkpointer.recover_state()
pass

# handle start_epoch better
self.start_epoch = 0
self.checkpointer = Checkpointer(
"val_loss",
checkpoint_path=self.experiment_folder,
accelerator=self.accelerator,
)

self.accelerator = Accelerator()
self.device = self.accelerator.device
self.modules.to(self.device)
print("Set device to ", self.device)
def init_devices(self):
"""Initializes the data pipeline and modules for DDP and accelerated inference.
To control the device selection, use `accelerate config`."""

convert = [self.modules]
if hasattr(self, "opt"):
convert += [self.opt]

if hasattr(self, "lr_sched"):
convert += [self.lr_sched]

if hasattr(self, "datasets"):
# if the datasets are stored here, prepare them for DDP
convert += list(self.datasets.values())

convert = [self.modules, self.opt, self.lr_sched] + list(self.datasets.values())
accelerated = self.accelerator.prepare(convert)
self.modules, self.opt, self.lr_sched = accelerated[:3]
for i, key in enumerate(list(self.datasets.keys())[::-1]):
self.datasets[key] = accelerated[-(i + 1)]
self.modules = accelerated[0]
self.accelerator.register_for_checkpointing(self.modules)

if hasattr(self, "opt"):
self.opt = accelerated[1]
self.accelerator.register_for_checkpointing(self.opt)

if os.path.exists(accelerate_dir):
self.accelerator.load_state(accelerate_dir)
if hasattr(self, "lr_sched"):
self.lr_sched = accelerated[2]
self.accelerator.register_for_checkpointing(self.lr_sched)

if hasattr(self, "datasets"):
for i, key in enumerate(list(self.datasets.keys())[::-1]):
self.datasets[key] = accelerated[-(i + 1)]

self.modules.to(self.device)

def on_train_end(self):
"""Runs at the end of each training. Cleans up before exiting."""
Expand All @@ -359,6 +366,7 @@ def train(
epochs: int = 1,
datasets: Dict = {},
metrics: List[Metric] = [],
checkpointer=None, # fix type hints
debug: bool = False,
) -> None:
"""
Expand All @@ -384,6 +392,7 @@ def train(
"""
self.datasets = datasets
self.metrics = metrics
self.checkpointer = checkpointer
assert "train" in self.datasets, "Training dataloader was not specified."
assert epochs > 0, "You must specify at least one epoch."

Expand Down Expand Up @@ -455,13 +464,14 @@ def train(

if "val" in datasets:
val_metrics = self.validate()
if self.accelerator.is_local_main_process:
if (
self.accelerator.is_local_main_process
and self.checkpointer is not None
):
self.checkpointer(
self,
e,
train_metrics,
val_metrics,
lambda x: self.accelerator.unwrap_model(x),
)
else:
val_metrics = train_metrics.update(
Expand Down
3 changes: 1 addition & 2 deletions micromind/utils/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def dump_modules(modules, out_folder):
def dump_status(status, out_dir):
yaml_status = yaml.dump(status)

print(out_dir)
with open(os.path.join(out_dir, "status.yaml"), "w") as f:
f.write(yaml_status)

Expand Down Expand Up @@ -95,7 +94,7 @@ def __call__(
self.dump_status(status_dict, current_folder)

# remove previous last dir after saving the current version
if os.path.exists(self.last_dir):
if os.path.exists(self.last_dir) and self.last_dir != self.check_paths:
shutil.rmtree(self.last_dir)

self.last_dir = current_folder
Expand Down

0 comments on commit 995462e

Please sign in to comment.