Skip to content

Commit

Permalink
update default collator
Browse files Browse the repository at this point in the history
  • Loading branch information
fcogidi committed Aug 22, 2024
1 parent 45e0626 commit 6c672fd
Show file tree
Hide file tree
Showing 16 changed files with 329 additions and 261 deletions.
8 changes: 4 additions & 4 deletions mmlearn/conf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
from omegaconf import II, MISSING, SI, DictConfig

from mmlearn.datasets.core.example import collate_example_list
from mmlearn.datasets.core.data_collator import DefaultDataCollator


def _get_default_ckpt_dir() -> Any:
Expand All @@ -35,8 +35,8 @@ def _get_default_ckpt_dir() -> Any:
populate_full_signature=True,
dataset=MISSING,
pin_memory=True,
collate_fn=collate_example_list,
hydra_convert="all",
collate_fn=DefaultDataCollator(),
hydra_convert="object",
)


Expand Down Expand Up @@ -152,7 +152,7 @@ class MMLearnConf:
metadata={"help": "Configuration for torch.jit.compile."},
)
hydra: HydraConf = HydraConf(
searchpath=["pkg://mmlearn/conf", "file://./configs"],
searchpath=["pkg://mmlearn/conf"],
run=RunDir(
dir=SI("./outputs/${experiment_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}")
),
Expand Down
17 changes: 6 additions & 11 deletions mmlearn/datasets/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
"""Modules for core dataloading functionality."""

from mmlearn.datasets.core.combined_dataset import CombinedDataset
from mmlearn.datasets.core.example import (
Example,
collate_example_list,
find_matching_indices,
)
from mmlearn.datasets.core.modalities import ModalityRegistry
from mmlearn.datasets.core.data_collator import DefaultDataCollator
from mmlearn.datasets.core.example import Example, find_matching_indices
from mmlearn.datasets.core.modalities import Modalities
from mmlearn.datasets.core.samplers import (
CombinedDatasetRatioSampler,
DistributedEvalSampler,
)


Modalities = ModalityRegistry()

__all__ = [
"CombinedDataset",
"Example",
"collate_example_list",
"find_matching_indices",
"CombinedDatasetRatioSampler",
"DefaultDataCollator",
"DistributedEvalSampler",
"Example",
"find_matching_indices",
"Modalities",
]
7 changes: 5 additions & 2 deletions mmlearn/datasets/core/combined_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,11 @@ def __getitem__(self, idx: int) -> Example:
"Expected dataset examples to be instances of `Example` "
f"but found {type(example)}",
)
example.dataset_index = dataset_idx
example.create_ids()

if not hasattr(example, "dataset_index"):
example.dataset_index = dataset_idx
if not hasattr(example, "example_ids"):
example.create_ids()

return example

Expand Down
132 changes: 132 additions & 0 deletions mmlearn/datasets/core/data_collator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""Data collators for batching examples."""

from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

from torch.utils.data import default_collate

from mmlearn.datasets.core.example import Example
from mmlearn.datasets.core.modalities import Modalities, Modality


@dataclass
class DefaultDataCollator:
    """Collate a list of `Example` objects into a batch dictionary.

    After collation, optional per-key processing functions are applied to the
    batch before it is returned.

    Parameters
    ----------
    batch_processors : Optional[dict[str, Callable[[Any], Any]]], default=None
        Mapping from a batch key to a processing function applied to that key's
        collated value. Each function takes a single argument and returns a
        single value. If a function returns a mapping, that mapping must contain
        the processed key; every entry of the mapping is merged into the batch.

    Raises
    ------
    ValueError
        If a batch processor returns a mapping that does not contain the key
        it was registered for.
    """

    batch_processors: Optional[dict[str, Callable[[Any], Any]]] = None

    def __call__(self, examples: list[Example]) -> dict[str, Any]:
        """Collate `examples` and run any configured batch processors."""
        batch = collate_example_list(examples)
        if self.batch_processors is None:
            return batch

        for name, process in self.batch_processors.items():
            # Processor keys may name a modality; resolve to the registry's
            # `Modality` object so lookup matches the collated batch keys.
            target: Union[str, Modality] = name
            if Modalities.has_modality(name):
                target = Modalities.get_modality(name)

            if target not in batch:
                continue

            result = process(batch[target])
            if not isinstance(result, Mapping):
                batch[target] = result
                continue

            if target not in result:
                raise ValueError(
                    f"Batch processor for '{name}' key must return a dictionary "
                    f"with '{target}' in it."
                )
            # Merge every entry the processor returned, not just the target key.
            batch.update(result)

        return batch


def collate_example_list(examples: list[Example]) -> dict[str, Any]:
    """Collate a list of `Example` objects into a batch.

    Parameters
    ----------
    examples : list[Example]
        List of examples to collate.

    Returns
    -------
    dict[str, Any]
        Dictionary of batched examples.
    """
    # First flatten the examples into per-key lists, then collate each list.
    merged = _merge_examples(examples)
    return _collate_example_dict(merged)


def _merge_examples(examples: list[Example]) -> dict[str, Any]:
"""Convert a list of `dataset.example.Example` objects into a dictionary.
This method will merge examples with the same key into a list.
Parameters
----------
examples : list[Example]
List of examples to convert.
Returns
-------
dict[str, Any]
Dictionary of converted examples.
"""
merged_examples: dict[str, Any] = {}
for example in examples:
for key in example:
if key in merged_examples:
merged_examples[key].append(example[key])
else:
merged_examples[key] = [example[key]]

for key in merged_examples:
if isinstance(merged_examples[key][0], Example):
merged_examples[key] = _merge_examples(merged_examples[key])

return merged_examples


def _collate_example_dict(examples: dict[str, Any]) -> dict[str, Any]:
"""Collate a dictionary of examples into a batch.
Parameters
----------
examples : dict[str, Any]
Dictionary of examples to collate.
Returns
-------
dict[str, Any]
Dictionary of collated examples.
"""
batch = {}
for k, v in examples.items():
if isinstance(v, dict):
batch[k] = _collate_example_dict(v)
else:
batch[k] = default_collate(v)

return batch
71 changes: 0 additions & 71 deletions mmlearn/datasets/core/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import torch
from lightning.fabric.utilities import rank_zero_warn
from torch.utils.data import default_collate


class Example(OrderedDict[Any, Any]):
Expand Down Expand Up @@ -91,76 +90,6 @@ def __setitem__(self, key: Hashable, value: Any) -> None:
super().__setitem__(key, value)


def _merge_examples(examples: list[Example]) -> dict[str, Any]:
    """Convert a list of `dataset.example.Example` objects into a dictionary.

    This method will merge examples with the same key into a list.

    Parameters
    ----------
    examples : list[Example]
        List of examples to convert.

    Returns
    -------
    dict[str, Any]
        Dictionary of converted examples.
    """
    merged_examples: dict[str, Any] = {}
    for example in examples:
        for key in example:
            # Accumulate values sharing a key across examples into one list.
            if key in merged_examples:
                merged_examples[key].append(example[key])
            else:
                merged_examples[key] = [example[key]]

    for key in merged_examples:
        # Nested `Example` values are merged recursively so the result is a
        # plain (possibly nested) dict of lists.
        if isinstance(merged_examples[key][0], Example):
            merged_examples[key] = _merge_examples(merged_examples[key])

    return merged_examples


def _collate_example_dict(examples: dict[str, Any]) -> dict[str, Any]:
    """Collate a dictionary of examples into a batch.

    Nested dictionaries are collated recursively; all other values are handed
    to `torch.utils.data.default_collate`.

    Parameters
    ----------
    examples : dict[str, Any]
        Dictionary of examples to collate.

    Returns
    -------
    dict[str, Any]
        Dictionary of collated examples.
    """
    batch = {}
    for k, v in examples.items():
        if isinstance(v, dict):
            batch[k] = _collate_example_dict(v)
        else:
            batch[k] = default_collate(v)

    return batch


def collate_example_list(examples: list[Example]) -> dict[str, Any]:
    """Collate a list of `Example` objects into a batch.

    Merges the examples into per-key lists, then collates each list into
    batched tensors/values.

    Parameters
    ----------
    examples : list[Example]
        List of examples to collate.

    Returns
    -------
    dict[str, Any]
        Dictionary of batched examples.
    """
    return _collate_example_dict(_merge_examples(examples))


def find_matching_indices(
first_example_ids: torch.Tensor,
second_example_ids: torch.Tensor,
Expand Down
3 changes: 3 additions & 0 deletions mmlearn/datasets/core/modalities.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,6 @@ def _is_format_string(string: str) -> bool:
"""
pattern = r"\{.*?\}"
return bool(re.search(pattern, string))


Modalities = ModalityRegistry()
2 changes: 1 addition & 1 deletion mmlearn/datasets/processors/tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __call__(
if isinstance(value, torch.Tensor):
batch_encoding[key] = torch.squeeze(value, 0)

# replace 'input_ids' with 'Modalities.TEXT' for consistency
# use 'Modalities.TEXT' key for input_ids for consistency
batch_encoding[Modalities.TEXT] = batch_encoding["input_ids"]
return dict(batch_encoding)

Expand Down
Loading

0 comments on commit 6c672fd

Please sign in to comment.