Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update versions for pre-commit hooks #41

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ repos:
- id: check-toml

- repo: https://github.com/python-poetry/poetry
rev: 1.8.4
rev: 2.0.1
hooks:
- id: poetry-check
args: [--lock]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.2
rev: v0.9.2
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -31,13 +31,13 @@ repos:
types_or: [ python, pyi, jupyter ]

- repo: https://github.com/crate-ci/typos
rev: v1.27.0
rev: v1.29.4
hooks:
- id: typos
args: []

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
rev: v1.14.1
hooks:
- id: mypy
entry: mypy
Expand All @@ -46,7 +46,7 @@ repos:
exclude: tests|projects

- repo: https://github.com/nbQA-dev/nbQA
rev: 1.9.0
rev: 1.9.1
hooks:
- id: nbqa-ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
6 changes: 3 additions & 3 deletions mmlearn/cli/_instantiators.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ def instantiate_sampler(
kwargs.update(distributed_sampler_kwargs)

sampler = hydra.utils.instantiate(cfg, **kwargs)
assert isinstance(
sampler, Sampler
), f"Expected a `torch.utils.data.Sampler` object but got {type(sampler)}."
assert isinstance(sampler, Sampler), (
f"Expected a `torch.utils.data.Sampler` object but got {type(sampler)}."
)

if sampler is None and requires_distributed_sampler:
sampler = DistributedSampler(dataset, **distributed_sampler_kwargs)
Expand Down
18 changes: 9 additions & 9 deletions mmlearn/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def main(cfg: MMLearnConf) -> None: # noqa: PLR0912
trainer: Trainer = hydra.utils.instantiate(
cfg.trainer, callbacks=callbacks, logger=loggers, _convert_="all"
)
assert isinstance(
trainer, Trainer
), "Trainer must be an instance of `lightning.pytorch.trainer.Trainer`"
assert isinstance(trainer, Trainer), (
"Trainer must be an instance of `lightning.pytorch.trainer.Trainer`"
)

if rank_zero_only.rank == 0 and loggers is not None: # update wandb config
for trainer_logger in loggers:
Expand All @@ -79,9 +79,9 @@ def main(cfg: MMLearnConf) -> None: # noqa: PLR0912
# prepare dataloaders
if cfg.job_type == JobType.train:
train_dataset = instantiate_datasets(cfg.datasets.train)
assert (
train_dataset is not None
), "Train dataset (`cfg.datasets.train`) is required for training."
assert train_dataset is not None, (
"Train dataset (`cfg.datasets.train`) is required for training."
)

train_sampler = instantiate_sampler(
cfg.dataloader.train.get("sampler"),
Expand Down Expand Up @@ -109,9 +109,9 @@ def main(cfg: MMLearnConf) -> None: # noqa: PLR0912
)
else:
test_dataset = instantiate_datasets(cfg.datasets.test)
assert (
test_dataset is not None
), "Test dataset (`cfg.datasets.test`) is required for evaluation."
assert test_dataset is not None, (
"Test dataset (`cfg.datasets.test`) is required for evaluation."
)

test_sampler = instantiate_sampler(
cfg.dataloader.test.get("sampler"),
Expand Down
12 changes: 6 additions & 6 deletions mmlearn/datasets/chexpert.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ def __init__(
transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
) -> None:
assert split in ["train", "valid"], f"split {split} is not available."
assert (
labeler in ["chexpert", "chexbert", "vchexbert"] or labeler is None
), f"labeler {labeler} is not available."
assert (
callable(transform) or transform is None
), "transform is not callable or None."
assert labeler in ["chexpert", "chexbert", "vchexbert"] or labeler is None, (
f"labeler {labeler} is not available."
)
assert callable(transform) or transform is None, (
"transform is not callable or None."
)

if split == "valid":
data_file = f"{split}_data.json"
Expand Down
2 changes: 1 addition & 1 deletion mmlearn/datasets/core/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _merge_examples(examples: list[Example]) -> dict[str, Any]:
else:
merged_examples[key] = [example[key]]

for key in merged_examples:
for key in merged_examples: # noqa: PLC0206
if isinstance(merged_examples[key][0], Example):
merged_examples[key] = _merge_examples(merged_examples[key])

Expand Down
6 changes: 3 additions & 3 deletions mmlearn/datasets/librispeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def __len__(self) -> int:
def __getitem__(self, idx: int) -> Example:
"""Return an example from the dataset."""
waveform, sample_rate, transcript, _, _, _ = self.dataset[idx]
assert (
sample_rate == SAMPLE_RATE
), f"Expected sample rate to be `16000`, got {sample_rate}."
assert sample_rate == SAMPLE_RATE, (
f"Expected sample rate to be `16000`, got {sample_rate}."
)
waveform = pad_or_trim(waveform.flatten())

return Example(
Expand Down
6 changes: 3 additions & 3 deletions mmlearn/datasets/nihcxr.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def __init__(
transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
) -> None:
assert split in ["train", "test", "bbox"], f"split {split} is not available."
assert (
callable(transform) or transform is None
), "transform is not callable or None."
assert callable(transform) or transform is None, (
"transform is not callable or None."
)

data_path = os.path.join(root_dir, split + "_data.json")

Expand Down
6 changes: 3 additions & 3 deletions mmlearn/modules/encoders/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,9 +512,9 @@ def forward(
masks: Union[torch.Tensor, list[torch.Tensor]],
) -> torch.Tensor:
"""Forward pass through the Vision Transformer Predictor."""
assert (masks is not None) and (
masks_x is not None
), "Cannot run predictor without mask indices"
assert (masks is not None) and (masks_x is not None), (
"Cannot run predictor without mask indices"
)

if not isinstance(masks_x, list):
masks_x = [masks_x]
Expand Down
6 changes: 3 additions & 3 deletions mmlearn/tasks/contrastive_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,9 @@ def __init__( # noqa: PLR0912, PLR0915
Modalities.get_modality(modality_key)
for modality_key in modality_encoder_mapping
]
assert (
len(self._available_modalities) >= 2
), "Expected at least two modalities to be available. "
assert len(self._available_modalities) >= 2, (
"Expected at least two modalities to be available. "
)

#: A :py:class:`~torch.nn.ModuleDict`, where the keys are the names of the
#: modalities and the values are the encoder modules.
Expand Down
6 changes: 3 additions & 3 deletions projects/med_benchmarking/datasets/mimiciv_cxr.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ def __getitem__(self, idx: int) -> Example:
)
if tokens is not None:
if isinstance(tokens, dict): # output of HFTokenizer
assert (
Modalities.TEXT.name in tokens
), f"Missing key `{Modalities.TEXT.name}` in tokens."
assert Modalities.TEXT.name in tokens, (
f"Missing key `{Modalities.TEXT.name}` in tokens."
)
example.update(tokens)
else:
example[Modalities.TEXT.name] = tokens
Expand Down
6 changes: 3 additions & 3 deletions projects/med_benchmarking/datasets/pmcoa.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def __getitem__(self, idx: int) -> Example:
tokens = self.tokenizer(caption) if self.tokenizer is not None else None
if tokens is not None:
if isinstance(tokens, dict): # output of HFTokenizer
assert (
Modalities.TEXT.name in tokens
), f"Missing key `{Modalities.TEXT.name}` in tokens."
assert Modalities.TEXT.name in tokens, (
f"Missing key `{Modalities.TEXT.name}` in tokens."
)
example.update(tokens)
else:
example[Modalities.TEXT.name] = tokens
Expand Down
6 changes: 3 additions & 3 deletions projects/med_benchmarking/datasets/quilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ def __getitem__(self, idx: int) -> Example:

if tokens is not None:
if isinstance(tokens, dict): # output of HFTokenizer
assert (
Modalities.TEXT.name in tokens
), f"Missing key `{Modalities.TEXT.name}` in tokens."
assert Modalities.TEXT.name in tokens, (
f"Missing key `{Modalities.TEXT.name}` in tokens."
)
example.update(tokens)
else:
example[Modalities.TEXT.name] = tokens
Expand Down
6 changes: 3 additions & 3 deletions projects/med_benchmarking/datasets/roco.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def __getitem__(self, idx: int) -> Example:

if tokens is not None:
if isinstance(tokens, dict): # output of HFTokenizer
assert (
Modalities.TEXT.name in tokens
), f"Missing key `{Modalities.TEXT.name}` in tokens."
assert Modalities.TEXT.name in tokens, (
f"Missing key `{Modalities.TEXT.name}` in tokens."
)
example.update(tokens)
else:
example[Modalities.TEXT.name] = tokens
Expand Down
2 changes: 1 addition & 1 deletion tests/datasets/test_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_collate_example_list():
result = DefaultDataCollator()(
[img_class_example, img_text_pair, audio_text_pair, nested_example],
)
for key in expected_result:
for key in expected_result: # noqa: PLC0206
assert key in result
if isinstance(expected_result[key], torch.Tensor):
assert torch.equal(result[key], expected_result[key])
Expand Down
Loading