diff --git a/sage/trainer/callbacks.py b/sage/trainer/callbacks.py index 20df065..5f4e14a 100644 --- a/sage/trainer/callbacks.py +++ b/sage/trainer/callbacks.py @@ -55,4 +55,4 @@ def on_validation_epoch_end(self, trainer, pl_module): valid_loss = float(metrics["valid_loss"]) checkpoint_path = f"epoch{current_epoch}-valid_loss{valid_loss:.3f}.ckpt" trainer.save_checkpoint(save_dir / checkpoint_path, weights_only=True) - self.self.epoch = next(self.save_epochs) + self.epoch = next(self.save_epochs) diff --git a/sage/trainer/trainer.py b/sage/trainer/trainer.py index 431b3fe..12a3396 100644 --- a/sage/trainer/trainer.py +++ b/sage/trainer/trainer.py @@ -185,7 +185,7 @@ def forward(self, batch, mode: str = "train"): def log_result(self, output: dict, unit: str = "step", prog_bar: bool = False): output = {f"{unit}/{k}": float(v) for k, v in output.items()} - self.log_dict(dictionary=output, + self.log_dict(dictionary=output, on_step=unit == "step", on_epoch=unit == "epoch", prog_bar=prog_bar)