clean up docs
epwalsh committed Jan 30, 2025
1 parent 564d8c5 commit b1430fe
Showing 5 changed files with 23 additions and 28 deletions.
5 changes: 5 additions & 0 deletions docs/source/guides/data_loading.rst
@@ -1,6 +1,11 @@
 Data loading
 ============
 
+.. note::
+    This guide is specific to text-based data; however, the :class:`~olmo_core.train.Trainer` can be
+    used with other modalities as well by creating a custom data loader subclass of
+    :class:`~olmo_core.data.data_loader.DataLoaderBase`.
+
 Using OLMo-core's builtin data loading
 --------------------------------------
7 changes: 3 additions & 4 deletions src/olmo_core/data/__init__.py
@@ -4,12 +4,11 @@
 Overview
 --------
-Prepare your data by writing token IDs to numpy arrays on disk, using the
+For text-based data, you should prepare your data by writing token IDs to numpy arrays on disk, using the
 `Dolma toolkit <https://allenai.github.io/dolma/>`_ for example.
-Configure and build your dataset using the :class:`~olmo_core.data.numpy_dataset.NumpyDatasetConfig`
+Then configure and build your dataset using the :class:`~olmo_core.data.numpy_dataset.NumpyDatasetConfig`
 builder, build your data loader with the :class:`~olmo_core.data.data_loader.NumpyDataLoaderConfig`
-builder, then pass it to :meth:`TrainerConfig.build() <olmo_core.train.TrainerConfig.build>`.
+builder, and pass it to :meth:`TrainerConfig.build() <olmo_core.train.TrainerConfig.build>`.
 """
 
 from .collator import DataCollator, PaddingDirection
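The overview above starts from token IDs written to numpy arrays on disk. A minimal sketch of that on-disk shape (the file name, dtype, and toy IDs are illustrative; real token IDs would come from a tokenizer pipeline such as the Dolma toolkit):

```python
import numpy as np

# Toy token IDs; uint16 suffices for vocabularies under 65,536 entries.
token_ids = np.array([5, 61, 302, 7, 1], dtype=np.uint16)

# Write the IDs to disk as one flat array.
mmap = np.memmap("tokens.bin", dtype=np.uint16, mode="w+", shape=token_ids.shape)
mmap[:] = token_ids
mmap.flush()

# Read them back with a memory map, so pages are loaded lazily
# rather than pulling the whole file into RAM at once.
loaded = np.memmap("tokens.bin", dtype=np.uint16, mode="r", shape=token_ids.shape)
```

Memory-mapping is what makes this layout practical for datasets far larger than RAM.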
2 changes: 1 addition & 1 deletion src/olmo_core/data/collator.py
@@ -23,7 +23,7 @@ class PaddingDirection(StrEnum):
 @dataclass
 class DataCollator:
     """
-    The default data collator used by the :class:`~olmo_core.train.Trainer`.
+    The default data collator used by :class:`~olmo_core.data.data_loader.TextDataLoaderBase` subclasses.
     """
 
     pad_token_id: int
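In plain terms, a collator like the one in this diff pads variable-length token sequences to a common length using `pad_token_id` and a padding direction. A standalone sketch of just that padding logic (the real `DataCollator` operates on tensors and builds additional fields such as attention masks; everything below is simplified for illustration):

```python
from enum import Enum
from typing import List


class PaddingDirection(str, Enum):
    left = "left"
    right = "right"


def collate(
    sequences: List[List[int]], pad_token_id: int, direction: PaddingDirection
) -> List[List[int]]:
    """Pad variable-length token ID sequences to the batch's max length."""
    max_len = max(len(seq) for seq in sequences)
    batch = []
    for seq in sequences:
        pad = [pad_token_id] * (max_len - len(seq))
        # Left padding prepends pad tokens; right padding appends them.
        batch.append(pad + seq if direction == PaddingDirection.left else seq + pad)
    return batch


# Right padding: the short sequence gets pad tokens appended.
collate([[1, 2, 3], [4]], pad_token_id=0, direction=PaddingDirection.right)
# → [[1, 2, 3], [4, 0, 0]]
```

Left padding is typically used for decoder-only generation, right padding for training batches.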
31 changes: 11 additions & 20 deletions src/olmo_core/data/data_loader.py
@@ -1,26 +1,6 @@
"""
Distributed, deterministic, stateful data loaders used by the :class:`~olmo_core.train.Trainer`.
Overview
--------
Construct a data loader from a :class:`~olmo_core.data.numpy_dataset.NumpyDatasetBase` instance
using :meth:`NumpyDataLoaderBase.wrap_numpy_dataset()`::
data_loader = NumpyDataLoaderBase.wrap_numpy_dataset(dataset, ...)
Then load batches for an epoch like this::
# Prepare for the epoch.
data_loader.reshuffle(epoch=1)
for batch in data_loader:
# process batch
pass
# Reset internal bookkeeping.
data_loader.reset()
"""

import logging
@@ -74,6 +54,17 @@ class DataLoaderBase(ABC):
     (i.e. before calling :meth:`__iter__`) and you must call :meth:`reset()` *after* each
     epoch (i.e. after the iterator returned from :meth:`__iter__` has been exhausted).
     Failure to do so will result in incorrect data order.
+
+    For example::
+
+        # Prepare for the epoch.
+        data_loader.reshuffle(epoch=1)
+
+        for batch in data_loader:
+            # process batch
+            pass
+
+        # Reset internal bookkeeping.
+        data_loader.reset()
 
     :param work_dir: The working directory. Should be shared among local ranks.
     :param global_batch_size: The global batch size. The units for this depend on the data loader
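The docstring example above can be fleshed out into a runnable toy that shows why calling `reset()` between epochs matters. Only the `reshuffle()`/`reset()` names come from the docstring; the rest of this sketch is hypothetical.

```python
import random


class TinyLoader:
    """Toy stateful loader (illustrative; not the real DataLoaderBase API)."""

    def __init__(self, n: int):
        self.n = n
        self._order = list(range(n))
        self._served = 0

    def reshuffle(self, epoch: int) -> None:
        # Seeded by epoch, so the order is deterministic across restarts.
        random.Random(epoch).shuffle(self._order)

    def __iter__(self):
        # Resumes from self._served, which is exactly why reset() is
        # required between epochs: stale bookkeeping truncates the epoch.
        while self._served < self.n:
            item = self._order[self._served]
            self._served += 1
            yield item

    def reset(self) -> None:
        self._served = 0


loader = TinyLoader(4)

loader.reshuffle(epoch=1)
epoch1 = list(loader)  # all 4 items
loader.reset()

loader.reshuffle(epoch=2)
epoch2 = list(loader)  # all 4 items again; without reset() this would be empty
```

Forgetting `reset()` leaves `_served` at its end-of-epoch value, so the next epoch silently yields nothing, which is the "incorrect data order" failure mode the docstring warns about.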
6 changes: 3 additions & 3 deletions src/olmo_core/train/__init__.py
@@ -23,10 +23,10 @@
     if __name__ == "__main__":
         prepare_training_environment()
         try:
-            # Build model, optimizer, dataset...
+            # Build train module and data loader...
 
             # Build trainer.
-            trainer = trainer_config.build(model, optim, dataset)
+            trainer = trainer_config.build(train_module, data_loader)
 
             # Run the trainer.
             trainer.fit()
@@ -87,7 +87,7 @@ def prepare_training_environment(
     Internally this calls:
 
     - :func:`~olmo_core.distributed.utils.init_distributed()`, which also calls :func:`torch.cuda.set_device()`
-      for backends that support CUDA.
+      for backends that support CUDA, otherwise :func:`torch.set_default_device()`.
     - :func:`~olmo_core.utils.prepare_cli_environment()`
 
     So there's no need to call those separately.
