Skip to content

Commit

Permalink
linting/formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
exs-whaddadin committed Oct 11, 2024
1 parent 6296acd commit e1e50c1
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
1 change: 0 additions & 1 deletion src/molflux/modelzoo/models/lightning_gp/gp_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, cast

import datasets

from molflux.modelzoo.models.lightning.datamodule import LightningDataModule
from molflux.modelzoo.models.lightning_gp.gp_config import GPConfig

Expand Down
22 changes: 11 additions & 11 deletions src/molflux/modelzoo/models/lightning_gp/gp_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union

from datasets import Dataset
import molflux
from datasets import Dataset

try:
import lightning.pytorch as pl
Expand Down Expand Up @@ -164,21 +164,21 @@ def from_dir(self, directory: str) -> None:

def _train_multi_data(
self,
train_data: Dict[Optional[str], Dataset],
train_data: dict[Optional[str], Dataset],
validation_data: Union[
Dict[Optional[str], Dataset],
dict[Optional[str], Dataset],
None,
] = None,
datamodule_config: Union[DataModuleConfig, Dict[str, Any], None] = None,
trainer_config: Union[TrainerConfig, Dict[str, Any], None] = None,
optimizer_config: Union[OptimizerConfig, Dict[str, Any], None] = None,
scheduler_config: Union[SchedulerConfig, Dict[str, Any], None] = None,
datamodule_config: Union[DataModuleConfig, dict[str, Any], None] = None,
trainer_config: Union[TrainerConfig, dict[str, Any], None] = None,
optimizer_config: Union[OptimizerConfig, dict[str, Any], None] = None,
scheduler_config: Union[SchedulerConfig, dict[str, Any], None] = None,
transfer_learning_config: Union[
TransferLearningConfigBase,
Dict[str, Any],
dict[str, Any],
None,
] = None,
compile_config: Union[CompileConfig, Dict[str, Any], bool, None] = None,
compile_config: Union[CompileConfig, dict[str, Any], bool, None] = None,
ckpt_path: Optional[str] = None,
**kwargs: Any,
) -> None:
Expand All @@ -191,7 +191,7 @@ def _train_multi_data(
strategy = molflux.splits.load_from_dict(
self.model_config.validation_config.splitting_strategy_config,
)
split_datasets: Dict[str, Dict[Any, Any]] = {
split_datasets: dict[str, dict[Any, Any]] = {
"train": {},
"validation": {},
}
Expand Down
4 changes: 2 additions & 2 deletions src/molflux/modelzoo/models/lightning_gp/gp_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Optional, Tuple
from typing import Any, Optional

try:
import gpytorch
Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(self, model_config: GPConfig) -> None:
self.likelihood,
self.gp_model,
)
self.validation_data: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
self.validation_data: Optional[tuple[torch.Tensor, torch.Tensor]] = None
self.train_data_set = False

def forward(
Expand Down

0 comments on commit e1e50c1

Please sign in to comment.