diff --git a/diffusion/utils/config.py b/diffusion/utils/config.py
index 209a076..37dc2f6 100644
--- a/diffusion/utils/config.py
+++ b/diffusion/utils/config.py
@@ -141,6 +141,7 @@ class TrainingConfig(BaseConfig):
     load_mask_index: bool = False
     snr_loss: bool = False
     real_prompt_ratio: float = 1.0
+    training_hours: float = 10000.0
     save_image_epochs: int = 1
     save_model_epochs: int = 1
     save_model_steps: int = 1000000
diff --git a/train_scripts/train.py b/train_scripts/train.py
index 2f6e391..72a3ec0 100755
--- a/train_scripts/train.py
+++ b/train_scripts/train.py
@@ -448,7 +448,10 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal
             if loss_nan_timer > 20:
                 raise ValueError("Loss is NaN too much times. Break here.")

-            if global_step % config.train.save_model_steps == 0 or (time.time() - training_start_time) / 3600 > 3.8:
+            if (
+                global_step % config.train.save_model_steps == 0
+                or (time.time() - training_start_time) / 3600 > config.train.training_hours
+            ):
                 accelerator.wait_for_everyone()
                 if accelerator.is_main_process:
                     os.umask(0o000)
@@ -469,7 +472,7 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal
                             f.write(osp.join(config.work_dir, "config.py") + "\n")
                             f.write(ckpt_saved_path)

-            if (time.time() - training_start_time) / 3600 > 3.8:
+            if (time.time() - training_start_time) / 3600 > config.train.training_hours:
                 logger.info(f"Stopping training at epoch {epoch}, step {global_step} due to time limit.")
                 return
             if config.train.visualize and (global_step % config.train.eval_sampling_steps == 0 or (step + 1) == 1):
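
A minimal sketch of the wall-clock check this patch introduces, using only the names visible in the diff; the standalone function `should_stop` and its arguments are illustrative, not part of the repo. The new `training_hours` field replaces the previously hard-coded 3.8-hour cutoff; its default of 10000.0 effectively disables the limit unless a config overrides it.

```python
import time

# Illustrative (hypothetical) helper mirroring the condition in train.py:
# stop once elapsed wall-clock time exceeds config.train.training_hours.
def should_stop(training_start_time: float, training_hours: float) -> bool:
    elapsed_hours = (time.time() - training_start_time) / 3600
    return elapsed_hours > training_hours

# Example: a job on a time-boxed cluster might pass training_hours=3.8
# to reproduce the old behavior of checkpointing and exiting early.
assert should_stop(time.time() - 4 * 3600, 3.8) is True
```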