diff --git a/.gitignore b/.gitignore index 103eee6..05591bd 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,7 @@ venv/ ENV/ env.bak/ venv.bak/ +.micromamba # Spyder project settings .spyderproject diff --git a/pl_crossvalidate/ensemble.py b/pl_crossvalidate/ensemble.py index fb613f6..ea1f795 100644 --- a/pl_crossvalidate/ensemble.py +++ b/pl_crossvalidate/ensemble.py @@ -30,8 +30,7 @@ class EnsembleLightningModule(LightningModule): def __init__(self, model: LightningModule, ckpt_paths: List[str]) -> None: super().__init__() - model_cls = type(model) - self.models = nn.ModuleList([model_cls.load_from_checkpoint(p) for p in ckpt_paths]) + self.models = nn.ModuleList([type(model).load_from_checkpoint(p) for p in ckpt_paths]) # We need to set the trainer to something to avoid errors model._trainer = object() diff --git a/pl_crossvalidate/trainer.py b/pl_crossvalidate/trainer.py index 86426be..5c374f3 100644 --- a/pl_crossvalidate/trainer.py +++ b/pl_crossvalidate/trainer.py @@ -218,16 +218,16 @@ def out_of_sample_score( # temporarily replace the predict_step method with the score method to use the trainer.predict method _orig_predict_method = model.predict_step - model.predict_step = model.score # run prection on each fold outputs = [] for i, ckpt_path in enumerate(ckpt_paths): self._set_fold_index(i, datamodule=datamodule) - model.load_from_checkpoint(ckpt_path) + model = type(model).load_from_checkpoint(ckpt_path) + model.predict_step = score_method out = self.predict(model=model, dataloaders=datamodule.test_dataloader()) + model.predict_step = _orig_predict_method outputs.append(torch.cat(out, 0)) - model.predict_step = _orig_predict_method # reorder to match the order of the dataset test_indices = torch.cat([torch.tensor(test) for _, test in datamodule.splits]) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index efc6fa1..4d8603a 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -30,7 +30,7 @@ def test_trainer_initialization(arguments, expected): @pytest.mark.parametrize("accelerator", ["cpu", "gpu"]) def test_cross_validate(tmp_path, accelerator): """Test cross validation finish a basic run.""" - if not torch.cuda.is_available() and torch.cuda.device_count() < 1: + if accelerator == "gpu" and not torch.cuda.is_available() and torch.cuda.device_count() < 1: pytest.skip("test requires cuda support") model = BoringModel()