Merge branch 'master' into feature/comet-logger-update
Lothiraldan authored Dec 12, 2024
2 parents dc048b1 + 110d621 commit 3c84f84
Showing 36 changed files with 578 additions and 37 deletions.
15 changes: 15 additions & 0 deletions .actions/assistant.py
@@ -483,6 +483,21 @@ def convert_version2nightly(ver_file: str = "src/version.info") -> None:


if __name__ == "__main__":
    import sys

    import jsonargparse
    from jsonargparse import ArgumentParser

    def patch_jsonargparse_python_3_12_8():
        if sys.version_info < (3, 12, 8):
            return

        def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
            namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False)  # type: ignore
            return namespace, args

        setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)

    patch_jsonargparse_python_3_12_8()  # Required until fix https://github.com/omni-us/jsonargparse/issues/641

    jsonargparse.CLI(AssistantCLI, as_positional=False)
4 changes: 3 additions & 1 deletion .azure/gpu-benchmarks.yml
@@ -75,7 +75,9 @@ jobs:
pip list
displayName: "Image info & NVIDIA"
-  - bash: pip install -e .[dev] --find-links ${TORCH_URL}
+  - bash: |
+      pip install -e .[dev] --find-links ${TORCH_URL}
+      pip install setuptools==75.6.0
    env:
      FREEZE_REQUIREMENTS: "1"
    displayName: "Install package"
1 change: 1 addition & 0 deletions .azure/gpu-tests-fabric.yml
@@ -107,6 +107,7 @@ jobs:
  - bash: |
      extra=$(python -c "print({'lightning': 'fabric-'}.get('$(PACKAGE_NAME)', ''))")
      pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}"
      pip install setuptools==75.6.0
    displayName: "Install package & dependencies"
- bash: |
1 change: 1 addition & 0 deletions .azure/gpu-tests-pytorch.yml
@@ -111,6 +111,7 @@ jobs:
  - bash: |
      extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))")
      pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}"
      pip install setuptools==75.6.0
    displayName: "Install package & dependencies"
  - bash: pip uninstall -y lightning
2 changes: 1 addition & 1 deletion _notebooks
3 changes: 2 additions & 1 deletion dockers/base-cuda/Dockerfile
@@ -59,7 +59,6 @@ RUN \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get install -y \
python${PYTHON_VERSION} \
-    python${PYTHON_VERSION}-distutils \
python${PYTHON_VERSION}-dev \
&& \
update-alternatives --install /usr/bin/python${PYTHON_VERSION%%.*} python${PYTHON_VERSION%%.*} /usr/bin/python${PYTHON_VERSION} 1 && \
@@ -79,6 +78,8 @@ RUN \
curl https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} && \
# Disable cache \
pip config set global.cache-dir false && \
+    # Install recent setuptools to obtain pkg_resources \
+    pip install setuptools==75.6.0 && \
# set particular PyTorch version \
pip install -q wget packaging && \
python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py && \
2 changes: 1 addition & 1 deletion dockers/docs/Dockerfile
@@ -44,7 +44,7 @@ RUN \
dvipng \
texlive-pictures \
python3 \
-    python3-distutils \
+    python3-setuptools \
python3-dev \
&& \
update-alternatives --install /usr/bin/python python /usr/bin/python3 1 && \
2 changes: 1 addition & 1 deletion dockers/release/Dockerfile
@@ -39,7 +39,7 @@ RUN \
fi && \
# otherwise there is collision with folder name and pkg name on Pypi
cd pytorch-lightning && \
-    pip install setuptools && \
+    pip install setuptools==75.6.0 && \
PACKAGE_NAME=lightning pip install '.[extra,loggers,strategies]' --no-cache-dir && \
PACKAGE_NAME=pytorch pip install '.[extra,loggers,strategies]' --no-cache-dir && \
cd .. && \
24 changes: 23 additions & 1 deletion docs/source-pytorch/common/checkpointing_basic.rst
@@ -20,6 +20,13 @@ PyTorch Lightning checkpoints are fully usable in plain PyTorch.

----

.. important::

    **Important Update: Deprecated Method**

    Starting from PyTorch Lightning v1.0.0, the ``resume_from_checkpoint`` argument is deprecated. To resume training from a checkpoint, pass the ``ckpt_path`` argument to the ``fit()`` method instead.
    Please update your code accordingly to avoid compatibility issues.

************************
Contents of a checkpoint
************************
@@ -197,16 +204,31 @@ You can disable checkpointing by passing:

----


*********************
Resume training state
*********************

If you want to restore the full training state (not just the model weights), do the following:

Correct usage:

.. code-block:: python

    model = LitModel()
    trainer = Trainer()

    # automatically restores model, epoch, step, LR schedulers, etc...
-    trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
+    trainer.fit(model, ckpt_path="path/to/your/checkpoint.ckpt")

.. warning::

    The argument ``resume_from_checkpoint`` has been deprecated in versions of PyTorch Lightning >= 1.0.0.
    To resume training from a checkpoint, use the ``ckpt_path`` argument in the ``fit()`` method instead.
Incorrect (deprecated) usage:

.. code-block:: python

    trainer = Trainer(resume_from_checkpoint="path/to/your/checkpoint.ckpt")
    trainer.fit(model)
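
For contrast, loading only the model weights (an illustrative sketch, not part of this diff; ``LitModel`` as defined above):

.. code-block:: python

    # Restores weights and hyperparameters only; trainer state is not touched.
    model = LitModel.load_from_checkpoint("path/to/your/checkpoint.ckpt")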
8 changes: 8 additions & 0 deletions docs/source-pytorch/common/index.rst
@@ -23,6 +23,7 @@
../data/data
../model/own_your_loop
../advanced/model_init
../common/tbptt


#############
@@ -202,6 +203,13 @@ How-to Guides
:col_css: col-md-4
:height: 180

.. displayitem::
:header: Truncated Back-Propagation Through Time
:description: Efficiently step through time when training recurrent models
:button_link: ../common/tbptt.html
:col_css: col-md-4
:height: 180

.. raw:: html

</div>
59 changes: 59 additions & 0 deletions docs/source-pytorch/common/tbptt.rst
@@ -0,0 +1,59 @@
##############################################
Truncated Backpropagation Through Time (TBPTT)
##############################################

Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of
a much longer sequence. To make this possible, each training batch is split along the
time dimension into chunks of size k before being passed to ``training_step``. To keep
the forward-propagation behavior unchanged, the hidden state must be carried over
between consecutive chunks.


.. code-block:: python

    import torch
    import torch.optim as optim
    import pytorch_lightning as pl
    from pytorch_lightning import LightningModule


    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()

            # 1. Switch to manual optimization
            self.automatic_optimization = False

            self.truncated_bptt_steps = 10
            self.my_rnn = ParityModuleRNN()  # Define RNN model using ParityModuleRNN

        # 2. Remove the `hiddens` argument
        def training_step(self, batch, batch_idx):
            # 3. Split the batch in chunks along the time dimension
            # (a sketch of the `split_batch` helper is given below)
            split_batches = split_batch(batch, self.truncated_bptt_steps)

            batch_size = 10
            hidden_dim = 20
            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)

            opt = self.optimizers()
            for split in split_batches:
                # 4. Perform the optimization in a loop
                loss, hiddens = self.my_rnn(split, hiddens)
                self.manual_backward(loss)
                opt.step()
                opt.zero_grad()

                # 5. "Truncate": detach so gradients stop at the chunk boundary
                hiddens = hiddens.detach()

            # 6. Remove the return of `hiddens`
            # Returning loss in manual optimization is not needed
            return None

        def configure_optimizers(self):
            return optim.Adam(self.my_rnn.parameters(), lr=0.001)


    if __name__ == "__main__":
        model = LitModel()
        trainer = pl.Trainer(max_epochs=5)
        trainer.fit(model, train_dataloader)  # Define your own dataloader
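
The example above relies on a user-defined ``split_batch`` helper. A minimal sketch
(illustrative, not part of this diff), assuming the batch is a tensor with the time
dimension first:

.. code-block:: python

    def split_batch(batch, k):
        # Split a (seq_len, batch_size, features) tensor into chunks of k time steps.
        return list(torch.split(batch, k, dim=0))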
4 changes: 3 additions & 1 deletion docs/source-pytorch/conf.py
@@ -462,7 +462,9 @@ def _load_py_module(name: str, location: str) -> ModuleType:
("py:obj", "lightning.pytorch.utilities.memory.is_out_of_cpu_memory"),
("py:func", "lightning.pytorch.utilities.rank_zero.rank_zero_only"),
("py:class", "lightning.pytorch.utilities.types.LRSchedulerConfig"),
("py:class", "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig"),
("py:class", "lightning.pytorch.utilities.types.LRSchedulerConfigType"),
("py:class", "lightning.pytorch.utilities.types.OptimizerConfigType"),
("py:class", "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfigType"),
("py:class", "lightning_habana.pytorch.plugins.precision.HPUPrecisionPlugin"),
("py:class", "lightning_habana.pytorch.strategies.HPUDDPStrategy"),
("py:class", "lightning_habana.pytorch.strategies.HPUParallelStrategy"),
18 changes: 17 additions & 1 deletion examples/fabric/reinforcement_learning/rl/utils.py
@@ -1,7 +1,6 @@
import argparse
import math
import os
-from distutils.util import strtobool
from typing import TYPE_CHECKING, Optional, Union

import gymnasium as gym
@@ -12,6 +11,23 @@
from rl.agent import PPOAgent, PPOLightningAgent


def strtobool(val):
    """Convert a string representation of truth to true (1) or false (0).

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
    Raises ValueError if 'val' is anything else.

    Note: taken from distutils after its deprecation.
    """
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if val in ("n", "no", "f", "false", "off", "0"):
        return 0
    raise ValueError(f"invalid truth value {val!r}")
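
# Editor's note (illustrative sketch, not part of this diff): such a helper is
# typically wired into argparse boolean flags, e.g. (flag name hypothetical):
#     parser.add_argument("--render", type=lambda x: bool(strtobool(x)), default=False)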


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--exp-name", type=str, default="default", help="the name of this experiment")
Expand Down
16 changes: 10 additions & 6 deletions src/lightning/fabric/loggers/tensorboard.py
@@ -220,15 +220,19 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
@override
@rank_zero_only
def log_hyperparams(
-        self, params: Union[dict[str, Any], Namespace], metrics: Optional[dict[str, Any]] = None
+        self,
+        params: Union[dict[str, Any], Namespace],
+        metrics: Optional[dict[str, Any]] = None,
+        step: Optional[int] = None,
) -> None:
"""Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the
hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to
display the new ones with hyperparameters.
Args:
params: a dictionary-like container with the hyperparameters
params: A dictionary-like container with the hyperparameters
metrics: Dictionary with metric names as keys and measured quantities as values
step: Optional global step number for the logged metrics
"""
params = _convert_params(params)
@@ -244,7 +248,7 @@ def log_hyperparams(
metrics = {"hp_metric": metrics}

if metrics:
-            self.log_metrics(metrics, 0)
+            self.log_metrics(metrics, step)

if _TENSORBOARD_AVAILABLE:
from torch.utils.tensorboard.summary import hparams
@@ -253,9 +257,9 @@ def log_hyperparams(

exp, ssi, sei = hparams(params, metrics)
writer = self.experiment._get_file_writer()
-        writer.add_summary(exp)
-        writer.add_summary(ssi)
-        writer.add_summary(sei)
+        writer.add_summary(exp, step)
+        writer.add_summary(ssi, step)
+        writer.add_summary(sei, step)
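
# Editor's note (illustrative sketch, not part of this diff): usage of the new
# `step` parameter might look like:
#
#     from lightning.fabric.loggers import TensorBoardLogger
#
#     logger = TensorBoardLogger(root_dir="logs")
#     # Hyperparameters and the initial metric value are logged at step 100
#     # instead of always at step 0.
#     logger.log_hyperparams({"lr": 1e-3}, metrics={"hp_metric": 0.0}, step=100)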

@override
@rank_zero_only
6 changes: 6 additions & 0 deletions src/lightning/fabric/plugins/precision/fsdp.py
@@ -74,6 +74,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
}
self._desired_input_dtype = precision_to_type[self.precision]

@override
def convert_module(self, module: Module) -> Module:
if "true" in self.precision:
return module.to(dtype=self._desired_input_dtype)
return module
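
# Editor's note (illustrative sketch, not part of this diff): with a "true"
# precision setting, the module's parameters are converted eagerly, e.g.:
#
#     import torch.nn as nn
#     from lightning.fabric.plugins.precision import FSDPPrecision
#
#     precision = FSDPPrecision("bf16-true")
#     module = precision.convert_module(nn.Linear(4, 4))  # params now torch.bfloat16
#     # With "bf16-mixed", the module would be returned unchanged.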

@property
def mixed_precision_config(self) -> "TorchMixedPrecision":
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
8 changes: 7 additions & 1 deletion src/lightning/fabric/strategies/deepspeed.py
@@ -18,6 +18,7 @@
import platform
from collections.abc import Mapping
from contextlib import AbstractContextManager, ExitStack
from datetime import timedelta
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -29,6 +30,7 @@
from typing_extensions import override

from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies.ddp import DDPStrategy
@@ -97,6 +99,7 @@ def __init__(
load_full_weights: bool = False,
precision: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
) -> None:
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
billion parameter models. `For more information: https://pytorch-
@@ -241,6 +244,7 @@ def __init__(
process_group_backend=process_group_backend,
)
self._backward_sync_control = None # DeepSpeed handles gradient accumulation internally
self._timeout: Optional[timedelta] = timeout

self.config = self._load_config(config)
if self.config is None:
@@ -648,7 +652,9 @@ def _init_deepspeed_distributed(self) -> None:
f"MEMBER: {self.global_rank + 1}/{self.world_size}"
)
self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(
+            self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+        )
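
# Editor's note (illustrative sketch, not part of this diff): the new `timeout`
# argument can be set when constructing the strategy, e.g.:
#
#     from datetime import timedelta
#     from lightning.fabric.strategies import DeepSpeedStrategy
#
#     # Raise the process-group timeout for slow rendezvous on large clusters.
#     strategy = DeepSpeedStrategy(timeout=timedelta(minutes=30))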

def _set_node_environment_variables(self) -> None:
assert self.cluster_environment is not None
16 changes: 16 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -11,6 +11,21 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- CometML logger was updated to support the recent Comet SDK ([#20275](https://github.com/Lightning-AI/pytorch-lightning/pull/20275))


## [unreleased] - YYYY-MM-DD

### Added

### Changed

- Merging of hparams when logging now ignores parameter names that begin with underscore `_` ([#20221](https://github.com/Lightning-AI/pytorch-lightning/pull/20221))

### Removed

### Fixed

- Fixed `LightningCLI` failing when both module and data module save hyperparameters due to conflicting internal `_class_path` parameter ([#20221](https://github.com/Lightning-AI/pytorch-lightning/pull/20221))


## [2.4.0] - 2024-08-06

### Added
@@ -41,6 +56,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))
- Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))
- Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))
- Fixed PyTorch Lightning FSDP taking more memory than PyTorch FSDP ([#20323](https://github.com/Lightning-AI/pytorch-lightning/pull/20323))


## [2.3.0] - 2024-06-13