diff --git a/micromind/core.py b/micromind/core.py index 835b1ec..2f887be 100644 --- a/micromind/core.py +++ b/micromind/core.py @@ -527,7 +527,6 @@ def train( self.accelerator.backward(loss) self.opt.step() - loss_epoch += loss.item() if hasattr(self, "lr_sched"): # ok for cos_lr self.lr_sched.step() diff --git a/micromind/utils/helpers.py b/micromind/utils/helpers.py index 43b51d7..1522857 100644 --- a/micromind/utils/helpers.py +++ b/micromind/utils/helpers.py @@ -31,8 +31,11 @@ def override_conf(hparams: Dict): """ parser = argparse.ArgumentParser(description="MicroMind experiment configuration.") for key, value in hparams.items(): - parser.add_argument(f"--{key}", type=type(value), default=value) - + parser.add_argument( + f"--{key}", + type=str2bool if isinstance(value, bool) else type(value), + default=value, + ) args, extra_args = parser.parse_known_args() for key, value in vars(args).items(): if value is not None: @@ -78,3 +81,14 @@ def get_logger(): logger.add(sys.stderr, format=fmt) return logger + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.")