diff --git a/.actions/assistant.py b/.actions/assistant.py index 15a20e63c61dc..6bb0bc201ed05 100644 --- a/.actions/assistant.py +++ b/.actions/assistant.py @@ -18,10 +18,11 @@ import shutil import tempfile import urllib.request +from collections.abc import Iterable, Iterator, Sequence from itertools import chain from os.path import dirname, isfile from pathlib import Path -from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple +from typing import Any, Optional from packaging.requirements import Requirement from packaging.version import Version @@ -127,7 +128,7 @@ def _parse_requirements(lines: Iterable[str]) -> Iterator[_RequirementWithCommen pip_argument = None -def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> List[str]: +def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> list[str]: """Loading requirements from a file. >>> path_req = os.path.join(_PROJECT_ROOT, "requirements") @@ -222,7 +223,7 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme fp.writelines([ln + os.linesep for ln in requires] + [os.linesep]) -def _retrieve_files(directory: str, *ext: str) -> List[str]: +def _retrieve_files(directory: str, *ext: str) -> list[str]: all_files = [] for root, _, files in os.walk(directory): for fname in files: @@ -232,7 +233,7 @@ def _retrieve_files(directory: str, *ext: str) -> List[str]: return all_files -def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]], lightning_by: str = "") -> List[str]: +def _replace_imports(lines: list[str], mapping: list[tuple[str, str]], lightning_by: str = "") -> list[str]: """Replace imports of standalone package to lightning. >>> lns = [ @@ -320,7 +321,7 @@ def copy_replace_imports( fo.writelines(lines) -def create_mirror_package(source_dir: str, package_mapping: Dict[str, str]) -> None: +def create_mirror_package(source_dir: str, package_mapping: dict[str, str]) -> None: """Create a mirror package with adjusted imports.""" # replace imports and copy the code mapping = package_mapping.copy() @@ -482,6 +483,21 @@ def convert_version2nightly(ver_file: str = "src/version.info") -> None: if __name__ == "__main__": + import sys + import jsonargparse + from jsonargparse import ArgumentParser + + def patch_jsonargparse_python_3_12_8(): + if sys.version_info < (3, 12, 8): + return + + def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]: + namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore + return namespace, args + + setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch) + + patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641 jsonargparse.CLI(AssistantCLI, as_positional=False) diff --git a/.azure/gpu-benchmarks.yml b/.azure/gpu-benchmarks.yml index 111589945e048..6825850ae01bd 100644 --- a/.azure/gpu-benchmarks.yml +++ b/.azure/gpu-benchmarks.yml @@ -46,7 +46,7 @@ jobs: variables: DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' ) container: - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0" options: "--gpus=all --shm-size=32g" strategy: matrix: @@ -75,7 +75,9 @@ jobs: pip list displayName: "Image info & NVIDIA" - - bash: pip install -e .[dev] --find-links ${TORCH_URL} + - bash: | + pip install -e .[dev] --find-links ${TORCH_URL} + pip install setuptools==75.6.0 env: FREEZE_REQUIREMENTS: "1" displayName: "Install package" diff --git a/.azure/gpu-tests-fabric.yml b/.azure/gpu-tests-fabric.yml index e63641b8ecc7d..9d7514f285bda 100644 --- a/.azure/gpu-tests-fabric.yml +++ b/.azure/gpu-tests-fabric.yml @@ -60,7 +60,7 @@ jobs: image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0" PACKAGE_NAME: "fabric" "Lightning | latest": - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0" PACKAGE_NAME: "lightning" workspace: clean: all @@ -107,6 +107,7 @@ jobs: - bash: | extra=$(python -c "print({'lightning': 'fabric-'}.get('$(PACKAGE_NAME)', ''))") pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}" + pip install setuptools==75.6.0 displayName: "Install package & dependencies" - bash: | @@ -134,13 +135,13 @@ jobs: condition: and(succeeded(), eq(variables['PACKAGE_NAME'], 'fabric')) displayName: "Adjust tests & examples" - - bash: python -m coverage run --source ${COVERAGE_SOURCE} -m pytest . -v --durations=50 - workingDirectory: tests/tests_fabric/ + - bash: python -m coverage run --source ${COVERAGE_SOURCE} -m pytest tests_fabric/ -v --durations=50 + workingDirectory: tests/ displayName: "Testing: fabric standard" timeoutInMinutes: "10" - - bash: bash ../run_standalone_tests.sh "." - workingDirectory: tests/tests_fabric/ + - bash: bash ./run_standalone_tests.sh "tests_fabric" + workingDirectory: tests/ env: PL_STANDALONE_TESTS_SOURCE: $(COVERAGE_SOURCE) displayName: "Testing: fabric standalone" @@ -157,7 +158,7 @@ jobs: ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \ --flags=gpu,pytest,${COVERAGE_SOURCE} --name="GPU-coverage" --env=linux,azure ls -l - workingDirectory: tests/tests_fabric/ + workingDirectory: tests/ displayName: "Statistics" - script: | diff --git a/.azure/gpu-tests-pytorch.yml b/.azure/gpu-tests-pytorch.yml index 56c0ace195ed0..b48961c7f3a26 100644 --- a/.azure/gpu-tests-pytorch.yml +++ b/.azure/gpu-tests-pytorch.yml @@ -53,7 +53,7 @@ jobs: image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0" PACKAGE_NAME: "pytorch" "Lightning | latest": - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.4-cuda12.1.0" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0" PACKAGE_NAME: "lightning" pool: lit-rtx-3090 variables: @@ -111,6 +111,7 @@ jobs: - bash: | extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))") pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}" + pip install setuptools==75.6.0 displayName: "Install package & dependencies" - bash: pip uninstall -y lightning @@ -155,13 +156,13 @@ jobs: ls -l checkpoints/ displayName: "Get legacy checkpoints" - - bash: python -m coverage run --source ${COVERAGE_SOURCE} -m pytest -v --durations=50 - workingDirectory: tests/tests_pytorch + - bash: python -m coverage run --source ${COVERAGE_SOURCE} -m pytest tests_pytorch/ -v --durations=50 + workingDirectory: tests/ displayName: "Testing: PyTorch standard" timeoutInMinutes: "35" - - bash: bash ../run_standalone_tests.sh "." - workingDirectory: tests/tests_pytorch + - bash: bash ./run_standalone_tests.sh "tests_pytorch" + workingDirectory: tests/ env: PL_USE_MOCKED_MNIST: "1" PL_STANDALONE_TESTS_SOURCE: $(COVERAGE_SOURCE) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index cdc2b63b2379d..e0b52f31804d6 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -5,39 +5,16 @@ # the repo. Unless a later match takes precedence, # @global-owner1 and @global-owner2 will be requested for # review when someone opens a pull request. -* @lantiga @borda @tchaton @awaelchli @justusschock - -# CI/CD and configs -/.actions/ @borda @ethanwharris @justusschock -/.github/ @borda @ethanwharris @justusschock -/.azure/ @borda @ethanwharris @justusschock -/dockers/ @borda @ethanwharris @justusschock -*.yml @borda @ethanwharris @justusschock +* @lantiga @borda @tchaton @justusschock @ethanwharris # Docs -/docs/ @lantiga @borda @awaelchli -/docs/*/conf.py @borda @awaelchli /.github/*.md @williamfalcon @lantiga @borda -/.github/ISSUE_TEMPLATE/ @borda @tchaton @awaelchli -/docs/source-fabric/conf.py @borda @awaelchli -/docs/source-fabric/index.rst @awaelchli @lantiga -/docs/source-pytorch/conf.py @borda @awaelchli +/docs/source-fabric/index.rst @williamfalcon @lantiga /docs/source-pytorch/index.rst @williamfalcon @lantiga /docs/source-pytorch/levels @williamfalcon @lantiga -# PyTorch Lightning -/src/lightning/pytorch @lantiga @borda @tchaton @awaelchli @justusschock - -# Lightning Data -/src/lightning/data/ @tchaton @lantiga - -# Lightning Fabric -/src/lightning/fabric @lantiga @borda @tchaton @awaelchli @justusschock - /.github/CODEOWNERS @williamfalcon /SECURITY.md @williamfalcon @lantiga /README.md @williamfalcon @lantiga -/setup.py @williamfalcon @borda -/src/pytorch_lightning/__about__.py @williamfalcon @borda -/src/lightning_fabric/__about__.py @williamfalcon @borda @awaelchli -/src/*/__setup__.py @borda @justusschock +/src/pytorch_lightning/__about__.py @williamfalcon @lantiga @borda +/src/lightning_fabric/__about__.py @williamfalcon @lantiga @borda diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml index 20875df42c5a8..6c7c6aa5f7c1b 100644 --- a/.github/checkgroup.yml +++ b/.github/checkgroup.yml @@ -19,27 +19,30 @@ subprojects: - "!*.md" - "!**/*.md" checks: - - "pl-cpu (macOS-13, lightning, 3.9, 2.1, oldest)" + - "pl-cpu (macOS-14, lightning, 3.9, 2.1, oldest)" - "pl-cpu (macOS-14, lightning, 3.10, 2.1)" - - "pl-cpu (macOS-14, lightning, 3.11, 2.2)" + - "pl-cpu (macOS-14, lightning, 3.11, 2.2.2)" - "pl-cpu (macOS-14, lightning, 3.11, 2.3)" - - "pl-cpu (macOS-14, lightning, 3.12, 2.4)" + - "pl-cpu (macOS-14, lightning, 3.12, 2.4.1)" + - "pl-cpu (macOS-14, lightning, 3.12, 2.5.1)" - "pl-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)" - "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1)" - - "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.2)" + - "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)" - "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.3)" - - "pl-cpu (ubuntu-20.04, lightning, 3.12, 2.4)" + - "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)" + - "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)" - "pl-cpu (windows-2022, lightning, 3.9, 2.1, oldest)" - "pl-cpu (windows-2022, lightning, 3.10, 2.1)" - - "pl-cpu (windows-2022, lightning, 3.11, 2.2)" + - "pl-cpu (windows-2022, lightning, 3.11, 2.2.2)" - "pl-cpu (windows-2022, lightning, 3.11, 2.3)" - - "pl-cpu (windows-2022, lightning, 3.12, 2.4)" + - "pl-cpu (windows-2022, lightning, 3.12, 2.4.1)" + - "pl-cpu (windows-2022, lightning, 3.12, 2.5.1)" - "pl-cpu (macOS-14, pytorch, 3.9, 2.1)" - "pl-cpu (ubuntu-20.04, pytorch, 3.9, 2.1)" - "pl-cpu (windows-2022, pytorch, 3.9, 2.1)" - - "pl-cpu (macOS-12, pytorch, 3.10, 2.1)" - - "pl-cpu (ubuntu-22.04, pytorch, 3.10, 2.1)" - - "pl-cpu (windows-2022, pytorch, 3.10, 2.1)" + - "pl-cpu (macOS-14, pytorch, 3.12, 2.5.1)" + - "pl-cpu (ubuntu-22.04, pytorch, 3.12, 2.5.1)" + - "pl-cpu (windows-2022, pytorch, 3.12, 2.5.1)" - id: "pytorch_lightning: Azure GPU" paths: @@ -86,14 +89,15 @@ subprojects: checks: - "lightning.Benchmarks" - - id: "pytorch-lightning: TPU workflow" - paths: - # tpu CI availability is very limited, so we only require tpu tests - # to pass when their configurations are modified - - ".github/workflows/tpu-tests.yml" - - "tests/tests_pytorch/run_tpu_tests.sh" - checks: - - "test-on-tpus (pytorch, pjrt, v4-8)" + # Temporarily disabled + # - id: "pytorch-lightning: TPU workflow" + # paths: + # # tpu CI availability is very limited, so we only require tpu tests + # # to pass when their configurations are modified + # - ".github/workflows/tpu-tests.yml" + # - "tests/tests_pytorch/run_tpu_tests.sh" + # checks: + # - "test-on-tpus (pytorch, pjrt, v4-8)" - id: "fabric: Docs" paths: @@ -141,15 +145,17 @@ subprojects: - "!*.md" - "!**/*.md" checks: - - "build-cuda (3.11, 2.1, 12.1.0)" - - "build-cuda (3.11, 2.2, 12.1.0)" - - "build-cuda (3.11, 2.3, 12.1.0)" - - "build-cuda (3.12, 2.4, 12.1.0)" + - "build-cuda (3.10, 2.1.2, 12.1.0)" + - "build-cuda (3.11, 2.2.2, 12.1.0)" + - "build-cuda (3.11, 2.3.1, 12.1.0)" + - "build-cuda (3.11, 2.4.1, 12.1.0)" + - "build-cuda (3.12, 2.5.1, 12.1.0)" #- "build-NGC" - - "build-pl (3.11, 2.1, 12.1.0)" + - "build-pl (3.10, 2.1, 12.1.0)" - "build-pl (3.11, 2.2, 12.1.0)" - "build-pl (3.11, 2.3, 12.1.0)" - - "build-pl (3.12, 2.4, 12.1.0)" + - "build-pl (3.11, 2.4, 12.1.0)" + - "build-pl (3.12, 2.5, 12.1.0)" # SECTION: lightning_fabric @@ -166,27 +172,30 @@ subprojects: - "!*.md" - "!**/*.md" checks: - - "fabric-cpu (macOS-13, lightning, 3.9, 2.1, oldest)" + - "fabric-cpu (macOS-14, lightning, 3.9, 2.1, oldest)" - "fabric-cpu (macOS-14, lightning, 3.10, 2.1)" - - "fabric-cpu (macOS-14, lightning, 3.11, 2.2)" + - "fabric-cpu (macOS-14, lightning, 3.11, 2.2.2)" - "fabric-cpu (macOS-14, lightning, 3.11, 2.3)" - - "fabric-cpu (macOS-14, lightning, 3.12, 2.4)" + - "fabric-cpu (macOS-14, lightning, 3.12, 2.4.1)" + - "fabric-cpu (macOS-14, lightning, 3.12, 2.5.1)" - "fabric-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)" - "fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.1)" - - "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2)" + - "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)" - "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.3)" - - "fabric-cpu (ubuntu-20.04, lightning, 3.12, 2.4)" + - "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)" + - "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)" - "fabric-cpu (windows-2022, lightning, 3.9, 2.1, oldest)" - "fabric-cpu (windows-2022, lightning, 3.10, 2.1)" - - "fabric-cpu (windows-2022, lightning, 3.11, 2.2)" + - "fabric-cpu (windows-2022, lightning, 3.11, 2.2.2)" - "fabric-cpu (windows-2022, lightning, 3.11, 2.3)" - - "fabric-cpu (windows-2022, lightning, 3.12, 2.4)" + - "fabric-cpu (windows-2022, lightning, 3.12, 2.4.1)" + - "fabric-cpu (windows-2022, lightning, 3.12, 2.5.1)" - "fabric-cpu (macOS-14, fabric, 3.9, 2.1)" - "fabric-cpu (ubuntu-20.04, fabric, 3.9, 2.1)" - "fabric-cpu (windows-2022, fabric, 3.9, 2.1)" - - "fabric-cpu (macOS-12, fabric, 3.10, 2.1)" - - "fabric-cpu (ubuntu-22.04, fabric, 3.10, 2.1)" - - "fabric-cpu (windows-2022, fabric, 3.10, 2.1)" + - "fabric-cpu (macOS-14, fabric, 3.12, 2.5.1)" + - "fabric-cpu (ubuntu-22.04, fabric, 3.12, 2.5.1)" + - "fabric-cpu (windows-2022, fabric, 3.12, 2.5.1)" - id: "lightning_fabric: Azure GPU" paths: @@ -258,14 +267,14 @@ subprojects: - "install-pkg (ubuntu-22.04, lightning, 3.11)" - "install-pkg (ubuntu-22.04, notset, 3.9)" - "install-pkg (ubuntu-22.04, notset, 3.11)" - - "install-pkg (macOS-12, fabric, 3.9)" - - "install-pkg (macOS-12, fabric, 3.11)" - - "install-pkg (macOS-12, pytorch, 3.9)" - - "install-pkg (macOS-12, pytorch, 3.11)" - - "install-pkg (macOS-12, lightning, 3.9)" - - "install-pkg (macOS-12, lightning, 3.11)" - - "install-pkg (macOS-12, notset, 3.9)" - - "install-pkg (macOS-12, notset, 3.11)" + - "install-pkg (macOS-14, fabric, 3.9)" + - "install-pkg (macOS-14, fabric, 3.11)" + - "install-pkg (macOS-14, pytorch, 3.9)" + - "install-pkg (macOS-14, pytorch, 3.11)" + - "install-pkg (macOS-14, lightning, 3.9)" + - "install-pkg (macOS-14, lightning, 3.11)" + - "install-pkg (macOS-14, notset, 3.9)" + - "install-pkg (macOS-14, notset, 3.11)" - "install-pkg (windows-2022, fabric, 3.9)" - "install-pkg (windows-2022, fabric, 3.11)" - "install-pkg (windows-2022, pytorch, 3.9)" diff --git a/.github/workflows/_build-packages.yml b/.github/workflows/_build-packages.yml index e0262ac63b685..cf6ed5379801b 100644 --- a/.github/workflows/_build-packages.yml +++ b/.github/workflows/_build-packages.yml @@ -19,37 +19,16 @@ defaults: shell: bash jobs: - init: - runs-on: ubuntu-22.04 - steps: - - uses: actions/checkout@v4 - - run: | - mkdir dist && touch dist/.placeholder - - name: Keep artifact - id: keep-artifact - run: python -c "print('DAYS=' + str(5 if '${{ github.event_name }}'.startswith('pull_request') else 0))" >> $GITHUB_OUTPUT - - uses: actions/upload-artifact@v3 - with: - name: ${{ inputs.artifact-name }} - path: dist - retention-days: ${{ steps.keep-artifact.outputs.DAYS }} - build-packages: - needs: init runs-on: ubuntu-22.04 strategy: - max-parallel: 1 # run sequential to prevent download/upload collisions matrix: pkg-name: ${{ fromJSON(inputs.pkg-names) }} steps: - uses: actions/checkout@v4 - - uses: actions/download-artifact@v3 - with: - name: ${{ inputs.artifact-name }} - path: pypi - uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: "3.x" - run: python -c "print('NB_DIRS=' + str(2 if '${{ matrix.pkg-name }}' == 'pytorch' else 1))" >> $GITHUB_ENV - name: Build & check package @@ -59,10 +38,33 @@ jobs: nb-dirs: ${{ env.NB_DIRS }} - run: | - mkdir pypi/${{ matrix.pkg-name }} + mkdir -p pypi/${{ matrix.pkg-name }} cp dist/* pypi/${{ matrix.pkg-name }}/ - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 + with: + name: ${{ inputs.artifact-name }}-${{ matrix.pkg-name }} + path: pypi + retention-days: 1 + + merge-artifacts: + needs: build-packages + runs-on: ubuntu-22.04 + steps: + - uses: actions/download-artifact@v4 + with: # download all build artifacts + pattern: ${{ inputs.artifact-name }}-* + merge-multiple: true + path: pypi + - run: | + sudo apt-get install -y tree + tree pypi + + - name: Keep artifact + run: python -c "print('DAYS=' + str(5 if '${{ github.event_name }}'.startswith('pull_request') else 0))" >> $GITHUB_ENV + - uses: actions/upload-artifact@v4 with: name: ${{ inputs.artifact-name }} path: pypi + retention-days: ${{ env.DAYS }} + if-no-files-found: error diff --git a/.github/workflows/_legacy-checkpoints.yml b/.github/workflows/_legacy-checkpoints.yml index 16072112b80a9..0161ab57bca52 100644 --- a/.github/workflows/_legacy-checkpoints.yml +++ b/.github/workflows/_legacy-checkpoints.yml @@ -60,7 +60,7 @@ jobs: - uses: actions/setup-python@v5 with: # Python version here needs to be supported by all PL versions listed in back-compatible-versions.txt. - python-version: 3.8 + python-version: "3.9" - name: Install PL from source env: @@ -104,11 +104,12 @@ jobs: python -c "print('AWS_RUN=' + str('' if '${{inputs.push_to_s3}}' == 'true' else '--dryrun'))" >> $GITHUB_ENV - name: Upload checkpoints to GitHub Actions artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: checkpoints-${{ github.sha }} path: ${{ env.LEGACY_FOLDER }}/checkpoints/ retention-days: ${{ env.KEEP_DAYS }} + include-hidden-files: true - run: pip install -r requirements/ci.txt - name: Upload checkpoints to S3 @@ -138,7 +139,7 @@ jobs: run: echo ${PL_VERSION} >> back-compatible-versions.txt - name: Create Pull Request - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v7 with: title: Adding test for legacy checkpoint created with ${{ env.PL_VERSION }} committer: GitHub diff --git a/.github/workflows/call-clear-cache.yml b/.github/workflows/call-clear-cache.yml index f1f0404299568..1dddbe8f72bb0 100644 --- a/.github/workflows/call-clear-cache.yml +++ b/.github/workflows/call-clear-cache.yml @@ -23,18 +23,18 @@ on: jobs: cron-clear: if: github.event_name == 'schedule' || github.event_name == 'pull_request' - uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.6 + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.9 with: - scripts-ref: v0.11.6 + scripts-ref: v0.11.8 dry-run: ${{ github.event_name == 'pull_request' }} pattern: "latest|docs" age-days: 7 direct-clear: if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' - uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.6 + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.9 with: - scripts-ref: v0.11.6 + scripts-ref: v0.11.8 dry-run: ${{ github.event_name == 'pull_request' }} pattern: ${{ inputs.pattern || 'pypi_wheels' }} # setting str in case of PR / debugging age-days: ${{ fromJSON(inputs.age-days) || 0 }} # setting 0 in case of PR / debugging diff --git a/.github/workflows/ci-check-md-links.yml b/.github/workflows/ci-check-md-links.yml index d60d4f1cfa322..d0dc889230112 100644 --- a/.github/workflows/ci-check-md-links.yml +++ b/.github/workflows/ci-check-md-links.yml @@ -14,7 +14,7 @@ on: jobs: check-md-links: - uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.11.6 + uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.11.9 with: config-file: ".github/markdown-links-config.json" base-branch: "master" diff --git a/.github/workflows/ci-pkg-install.yml b/.github/workflows/ci-pkg-install.yml index d22a8d3ace1e2..61055c9b5ac3d 100644 --- a/.github/workflows/ci-pkg-install.yml +++ b/.github/workflows/ci-pkg-install.yml @@ -42,7 +42,7 @@ jobs: strategy: fail-fast: false matrix: - os: ["ubuntu-22.04", "macOS-12", "windows-2022"] + os: ["ubuntu-22.04", "macOS-14", "windows-2022"] pkg-name: ["fabric", "pytorch", "lightning", "notset"] python-version: ["3.9", "3.11"] steps: @@ -50,7 +50,7 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: dist-packages-${{ github.sha }} path: dist diff --git a/.github/workflows/ci-schema.yml b/.github/workflows/ci-schema.yml index 632366a211177..32cd82f12784b 100644 --- a/.github/workflows/ci-schema.yml +++ b/.github/workflows/ci-schema.yml @@ -8,7 +8,7 @@ on: jobs: check: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.6 + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.9 with: # skip azure due to the wrong schema file by MSFT # https://github.com/Lightning-AI/lightning-flash/pull/1455#issuecomment-1244793607 diff --git a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml index 06616650deb9c..c2fda73dcf1f4 100644 --- a/.github/workflows/ci-tests-fabric.yml +++ b/.github/workflows/ci-tests-fabric.yml @@ -43,21 +43,24 @@ jobs: - { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } - - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } + - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" } + - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" } - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" } - - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" } + - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" } + - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" } + - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" } # only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues - - { os: "macOS-12", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" } - - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" } - - { os: "windows-2022", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.1" } + - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.5.1" } # "oldest" versions tests, only on minimum Python - - { os: "macOS-13", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" } + - { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" } - { os: "ubuntu-20.04", pkg-name: "lightning", @@ -98,7 +101,10 @@ jobs: - name: Set min. dependencies if: ${{ matrix.requires == 'oldest' }} - run: python .actions/assistant.py replace_oldest_ver + run: | + python .actions/assistant.py replace_oldest_ver + pip install "cython<3.0" wheel + pip install "pyyaml==5.4" --no-build-isolation - name: Adjust PyTorch versions in requirements files if: ${{ matrix.requires != 'oldest' }} @@ -171,7 +177,7 @@ jobs: coverage xml - name: Upload coverage to Codecov - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 # see: https://github.com/actions/toolkit/issues/399 continue-on-error: true with: diff --git a/.github/workflows/ci-tests-pytorch.yml b/.github/workflows/ci-tests-pytorch.yml index 4de22a24f36e6..ba8519ea8ed8a 100644 --- a/.github/workflows/ci-tests-pytorch.yml +++ b/.github/workflows/ci-tests-pytorch.yml @@ -47,21 +47,24 @@ jobs: - { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } - - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2" } + - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" } + - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" } - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" } - - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4" } + - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" } + - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.4.1" } + - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12", pytorch-version: "2.5.1" } # only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues - - { os: "macOS-12", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" } - - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" } - - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.1" } + - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" } + - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.5.1" } # "oldest" versions tests, only on minimum Python - - { os: "macOS-13", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" } + - { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" } - { os: "ubuntu-20.04", pkg-name: "lightning", @@ -103,7 +106,10 @@ jobs: - name: Set min. dependencies if: ${{ matrix.requires == 'oldest' }} - run: python .actions/assistant.py replace_oldest_ver + run: | + python .actions/assistant.py replace_oldest_ver + pip install "cython<3.0" wheel + pip install "pyyaml==5.4" --no-build-isolation - name: Adjust PyTorch versions in requirements files if: ${{ matrix.requires != 'oldest' }} @@ -208,7 +214,7 @@ jobs: coverage xml - name: Upload coverage to Codecov - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 # see: https://github.com/actions/toolkit/issues/399 continue-on-error: true with: diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 6df2b8cbb73d3..09ae3adc45ac6 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -43,10 +43,11 @@ jobs: include: # We only release one docker image per PyTorch version. # Make sure the matrix here matches the one below. - - { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" } + - { python_version: "3.10", pytorch_version: "2.1", cuda_version: "12.1.0" } - { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" } - { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" } - - { python_version: "3.12", pytorch_version: "2.4", cuda_version: "12.1.0" } + - { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" } + - { python_version: "3.12", pytorch_version: "2.5", cuda_version: "12.1.0" } steps: - uses: actions/checkout@v4 with: @@ -103,10 +104,11 @@ jobs: include: # These are the base images for PL release docker images. # Make sure the matrix here matches the one above. - - { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" } - - { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" } - - { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" } - - { python_version: "3.12", pytorch_version: "2.4", cuda_version: "12.1.0" } + - { python_version: "3.10", pytorch_version: "2.1.2", cuda_version: "12.1.0" } + - { python_version: "3.11", pytorch_version: "2.2.2", cuda_version: "12.1.0" } + - { python_version: "3.11", pytorch_version: "2.3.1", cuda_version: "12.1.0" } + - { python_version: "3.11", pytorch_version: "2.4.1", cuda_version: "12.1.0" } + - { python_version: "3.12", pytorch_version: "2.5.1", cuda_version: "12.1.0" } steps: - uses: actions/checkout@v4 - uses: docker/setup-buildx-action@v3 @@ -115,6 +117,12 @@ jobs: with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} + + - name: shorten Torch version + run: | + # convert 1.10.2 to 1.10 + pt_version=$(echo ${{ matrix.pytorch_version }} | cut -d. -f1,2) + echo "PT_VERSION=$pt_version" >> $GITHUB_ENV - uses: docker/build-push-action@v6 with: build-args: | @@ -123,7 +131,7 @@ jobs: CUDA_VERSION=${{ matrix.cuda_version }} file: dockers/base-cuda/Dockerfile push: ${{ env.PUSH_NIGHTLY }} - tags: "pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ matrix.pytorch_version }}-cuda${{ matrix.cuda_version }}" + tags: "pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ env.PT_VERSION }}-cuda${{ matrix.cuda_version }}" timeout-minutes: 95 - uses: ravsamhq/notify-slack-action@v2 if: failure() && env.PUSH_NIGHTLY == 'true' diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 8f385fcb39fd7..4443ff9d42a4a 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -129,11 +129,12 @@ jobs: run: echo "ARTIFACT_DAYS=7" >> $GITHUB_ENV - name: Upload built docs if: ${{ matrix.target == 'html' }} - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: docs-${{ matrix.pkg-name }}-${{ github.sha }} path: docs/build/html/ retention-days: ${{ env.ARTIFACT_DAYS }} + include-hidden-files: true #- name: Dump handy wheels # if: github.event_name == 'push' && github.ref == 'refs/heads/master' @@ -157,7 +158,7 @@ jobs: # use input if dispatch or git tag VERSION: ${{ inputs.version || github.ref_name }} steps: - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: docs-${{ matrix.pkg-name }}-${{ github.sha }} path: docs/build/html/ diff --git a/.github/workflows/docs-tutorials.yml b/.github/workflows/docs-tutorials.yml index e4d78483fa81b..5879a7dd58744 100644 --- a/.github/workflows/docs-tutorials.yml +++ b/.github/workflows/docs-tutorials.yml @@ -48,7 +48,7 @@ jobs: - name: Create Pull Request if: ${{ github.event_name != 'pull_request' && env.SHA_ACTUAL != env.SHA_LATEST }} - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v7 with: title: "docs: update ref to latest tutorials" committer: GitHub diff --git a/.github/workflows/release-nightly.yml b/.github/workflows/release-nightly.yml index 9578f84b87093..396e485b90065 100644 --- a/.github/workflows/release-nightly.yml +++ b/.github/workflows/release-nightly.yml @@ -44,6 +44,7 @@ jobs: with: name: nightly-packages-${{ github.sha }} path: dist + include-hidden-files: true publish-packages: runs-on: ubuntu-22.04 diff --git a/.github/workflows/release-pkg.yml b/.github/workflows/release-pkg.yml index a11751c13790e..c7828d70f7103 100644 --- a/.github/workflows/release-pkg.yml +++ b/.github/workflows/release-pkg.yml @@ -38,7 +38,7 @@ jobs: if: github.event_name == 'release' steps: - uses: actions/checkout@v4 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: dist-packages-${{ github.sha }} path: dist @@ -104,7 +104,7 @@ jobs: - name: Create Pull Request if: github.event_name != 'pull_request' - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v7 with: title: "Bump lightning ver `${{ env.TAG }}`" committer: GitHub @@ -140,7 +140,7 @@ jobs: name: ["FABRIC", "PYTORCH", "LIGHTNING"] steps: - uses: actions/checkout@v4 # needed for local action below - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: dist-packages-${{ github.sha }} path: dist @@ -165,7 +165,7 @@ jobs: name: ["FABRIC", "PYTORCH", "LIGHTNING"] steps: - uses: actions/checkout@v4 # needed for local action below - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: dist-packages-${{ github.sha }} path: dist diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml.disabled similarity index 100% rename from .github/workflows/tpu-tests.yml rename to .github/workflows/tpu-tests.yml.disabled diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cb7604831767b..c5e65de1d7eb7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -58,7 +58,7 @@ repos: #args: ["--write-changes"] # uncomment if you want to get automatic fixing - repo: https://github.com/PyCQA/docformatter - rev: v1.7.5 + rev: 06907d0267368b49b9180eed423fae5697c1e909 # todo: fix for docformatter after last 1.7.5 hooks: - id: docformatter additional_dependencies: [tomli] @@ -74,7 +74,7 @@ repos: hooks: # try to fix what is possible - id: ruff - args: ["--fix"] + args: ["--fix", "--unsafe-fixes"] # perform formatting updates - id: ruff-format # validate if all is fine with preview mode diff --git a/README.md b/README.md index d0c5a26bcab7e..aa58c6d8a585a 100644 --- a/README.md +++ b/README.md @@ -585,7 +585,6 @@ Lightning is rigorously tested across multiple CPUs, GPUs and TPUs and against m | System / PyTorch ver. | 1.13 | 2.0 | 2.1 | | :--------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| | Linux py3.9 \[GPUs\] | | | [![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Fpytorch-lightning%20%28GPUs%29?branchName=master)](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=24&branchName=master) | -| Linux py3.9 \[TPUs\] | | [![Test PyTorch - TPU](https://github.com/Lightning-AI/lightning/actions/workflows/tpu-tests.yml/badge.svg)](https://github.com/Lightning-AI/lightning/actions/workflows/tpu-tests.yml) | | | Linux (multiple Python versions) | [![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | [![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | [![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | | OSX (multiple Python versions) | [![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | [![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | [![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | | Windows (multiple Python versions) | [![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | [![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | [![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) | diff --git a/_notebooks b/_notebooks index e0720299da014..1e0e807329216 160000 --- a/_notebooks +++ b/_notebooks @@ -1 +1 @@ -Subproject commit e0720299da014bfaaeb50dea6778b962e28ca69d +Subproject commit 1e0e80732921606b641a4ab0e2eeebc93a60308f diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index 9153fd33ebef5..0e56f2fa93bd9 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -59,7 +59,6 @@ RUN \ add-apt-repository ppa:deadsnakes/ppa && \ apt-get install -y \ python${PYTHON_VERSION} \ - python${PYTHON_VERSION}-distutils \ python${PYTHON_VERSION}-dev \ && \ update-alternatives --install /usr/bin/python${PYTHON_VERSION%%.*} python${PYTHON_VERSION%%.*} /usr/bin/python${PYTHON_VERSION} 1 && \ @@ -79,6 +78,8 @@ RUN \ curl https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} && \ # Disable cache \ pip config set global.cache-dir false && \ + # Install recent setuptools to obtain pkg_resources \ + pip install setuptools==75.6.0 && \ # set particular PyTorch version \ pip install -q wget packaging && \ python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py && \ diff --git a/dockers/docs/Dockerfile b/dockers/docs/Dockerfile index 46882c2dc50bf..ec590bf182ee2 100644 --- a/dockers/docs/Dockerfile +++ b/dockers/docs/Dockerfile @@ -44,7 +44,7 @@ RUN \ dvipng \ texlive-pictures \ python3 \ - python3-distutils \ + python3-setuptools \ python3-dev \ && \ update-alternatives --install /usr/bin/python python /usr/bin/python3 1 && \ diff --git a/dockers/release/Dockerfile b/dockers/release/Dockerfile index 9bc6c0104574e..d791d8875e8bc 100644 --- a/dockers/release/Dockerfile +++ b/dockers/release/Dockerfile @@ -39,7 +39,7 @@ RUN \ fi && \ # otherwise there is collision with folder name and pkg name on Pypi cd pytorch-lightning && \ - pip install setuptools && \ + pip install setuptools==75.6.0 && \ PACKAGE_NAME=lightning pip install '.[extra,loggers,strategies]' --no-cache-dir && \ PACKAGE_NAME=pytorch pip install '.[extra,loggers,strategies]' --no-cache-dir && \ cd .. && \ diff --git a/docs/source-fabric/_static/images/icon.svg b/docs/source-fabric/_static/images/icon.svg index e88fc19036178..3272f7f87d0fc 100644 --- a/docs/source-fabric/_static/images/icon.svg +++ b/docs/source-fabric/_static/images/icon.svg @@ -1,9 +1,12 @@ - - - - - - - - + + + + + + + + + + + diff --git a/docs/source-fabric/_static/images/logo-large.svg b/docs/source-fabric/_static/images/logo-large.svg index 39531f95e9dba..b4814805e2ddf 100644 --- a/docs/source-fabric/_static/images/logo-large.svg +++ b/docs/source-fabric/_static/images/logo-large.svg @@ -1,9 +1,12 @@ - - - - - - - - + + + + + + + + + + + diff --git a/docs/source-fabric/_static/images/logo-small.svg b/docs/source-fabric/_static/images/logo-small.svg index 1f523a57c4a16..aac0b9618ab37 100644 --- a/docs/source-fabric/_static/images/logo-small.svg +++ b/docs/source-fabric/_static/images/logo-small.svg @@ -1,9 +1,12 @@ - - - - - - - - + + + + + + + + + + + diff --git a/docs/source-fabric/advanced/compile.rst b/docs/source-fabric/advanced/compile.rst index 17ba6e4ca9dc8..df79454f67a6f 100644 --- a/docs/source-fabric/advanced/compile.rst +++ b/docs/source-fabric/advanced/compile.rst @@ -115,9 +115,115 @@ always exclude the first call to ``forward()`` from your measurements, since it Compile median time: 0.0185 seconds Speedup: 1.4x - ---- +********************************************** +Apply torch.compile with ModelParallelStrategy +********************************************** + +:func:`torch.compile` can also be invoked as part of the `parallelize_fn` argument of :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy`. + +This is particularly handy when :func:`torch.compile` is used in combination with the `torch.distributed.tensor` API. + +Here is an example: + +.. code-block:: python + + import lightning as L + import torch + import torch.nn as nn + import torch.nn.functional as F + from lightning.pytorch.demos import Transformer + from lightning.fabric.strategies.model_parallel import ModelParallelStrategy + from torch.distributed._composable.fsdp.fully_shard import fully_shard + from torch.distributed.device_mesh import DeviceMesh + + def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + for module in model.modules(): + if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)): + fully_shard(module, mesh=device_mesh) + + fully_shard(model, mesh=device_mesh) + + return torch.compile(model) + + def train(): + L.seed_everything(42) + + with torch.device("meta"): + model = Transformer( + vocab_size=50257, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=parallelize) + + fabric = L.Fabric(precision="bf16-true", strategy=strategy) + fabric.launch() + + model = fabric.setup(model) + +The advantage here is that `parallelize` is called when sharding the model, +so :func:`torch.compile` is guaranteed to run on model shards and capture distributed operations. + +Also, when using other libraries like `torch ao `_ +that need to be applied in a similar fashion, it's easy to reason about the sequence of calls +needed to achieve the equivalent of `compile(distributed(quantized(model)))`: + +.. code-block:: python + + import lightning as L + import torch + import torch.nn as nn + import torch.nn.functional as F + from lightning.pytorch.demos import Transformer + from torch.distributed._composable.fsdp.fully_shard import fully_shard + from torch.distributed.device_mesh import DeviceMesh + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + float8_config = Float8LinearConfig( + pad_inner_dim=True, + ) + + def module_filter_fn(mod: torch.nn.Module, fqn: str): + return fqn != "decoder" + + convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn) + + for module in model.modules(): + if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)): + fully_shard(module, mesh=device_mesh) + + fully_shard(model, mesh=device_mesh) + + return torch.compile(model) + + def train(): + L.seed_everything(42) + + with torch.device("meta"): + model = Transformer( + vocab_size=50257, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=parallelize) + + fabric = L.Fabric(precision="bf16-true", strategy=strategy) + fabric.launch() + + model = fabric.setup(model) + +For a full example, see our `FP8 Distributed Transformer example `_. + +---- ****************** Avoid graph breaks diff --git a/docs/source-fabric/fundamentals/launch.rst b/docs/source-fabric/fundamentals/launch.rst index f8c0deecf4e25..81b6cd9d186f1 100644 --- a/docs/source-fabric/fundamentals/launch.rst +++ b/docs/source-fabric/fundamentals/launch.rst @@ -116,7 +116,7 @@ This is essentially the same as running ``python path/to/your/script.py``, but i machine. --precision [16-mixed|bf16-mixed|32-true|64-true|64|32|16|bf16] Double precision (``64-true`` or ``64``), - full precision (``32-true`` or ``64``), half + full precision (``32-true`` or ``32``), half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``) diff --git a/docs/source-pytorch/_static/images/icon.svg b/docs/source-pytorch/_static/images/icon.svg index 481762a961dda..aac0b9618ab37 100644 --- a/docs/source-pytorch/_static/images/icon.svg +++ b/docs/source-pytorch/_static/images/icon.svg @@ -1,3 +1,12 @@ - + + + + + + + + + + diff --git a/docs/source-pytorch/accelerators/gpu_intermediate.rst b/docs/source-pytorch/accelerators/gpu_intermediate.rst index 023fd02c185a8..2774a4cf8fc6f 100644 --- a/docs/source-pytorch/accelerators/gpu_intermediate.rst +++ b/docs/source-pytorch/accelerators/gpu_intermediate.rst @@ -26,7 +26,7 @@ Lightning supports multiple ways of doing distributed training. If you request multiple GPUs or nodes without setting a strategy, DDP will be automatically used. For a deeper understanding of what Lightning is doing, feel free to read this -`guide `_. +`guide `_. ---- diff --git a/docs/source-pytorch/accelerators/tpu_advanced.rst b/docs/source-pytorch/accelerators/tpu_advanced.rst index e410c6e82539f..d74f9b07374c9 100644 --- a/docs/source-pytorch/accelerators/tpu_advanced.rst +++ b/docs/source-pytorch/accelerators/tpu_advanced.rst @@ -52,7 +52,7 @@ Example: model = WeightSharingModule() trainer = Trainer(max_epochs=1, accelerator="tpu") -See `XLA Documentation `_ +See `XLA Documentation `_ ---- @@ -61,4 +61,4 @@ XLA XLA is the library that interfaces PyTorch with the TPUs. For more information check out `XLA `_. -Guide for `troubleshooting XLA `_ +Guide for `troubleshooting XLA `_ diff --git a/docs/source-pytorch/accelerators/tpu_basic.rst b/docs/source-pytorch/accelerators/tpu_basic.rst index fb4e2b7bde244..217b76106aea9 100644 --- a/docs/source-pytorch/accelerators/tpu_basic.rst +++ b/docs/source-pytorch/accelerators/tpu_basic.rst @@ -108,7 +108,7 @@ There are cases in which training on TPUs is slower when compared with GPUs, for - XLA Graph compilation during the initial steps `Reference `_ - Some tensor ops are not fully supported on TPU, or not supported at all. These operations will be performed on CPU (context switch). -The official PyTorch XLA `performance guide `_ +The official PyTorch XLA `performance guide `_ has more detailed information on how PyTorch code can be optimized for TPU. In particular, the -`metrics report `_ allows +`metrics report `_ allows one to identify operations that lead to context switching. diff --git a/docs/source-pytorch/accelerators/tpu_faq.rst b/docs/source-pytorch/accelerators/tpu_faq.rst index f4b2c60633d26..766f9dcacb32e 100644 --- a/docs/source-pytorch/accelerators/tpu_faq.rst +++ b/docs/source-pytorch/accelerators/tpu_faq.rst @@ -40,9 +40,9 @@ Unsupported datatype transfer to TPUs? .. code-block:: - File "/usr/local/lib/python3.8/dist-packages/torch_xla/utils/utils.py", line 205, in _for_each_instance_rewrite + File "/usr/local/lib/python3.9/dist-packages/torch_xla/utils/utils.py", line 205, in _for_each_instance_rewrite v = _for_each_instance_rewrite(result.__dict__[k], select_fn, fn, rwmap) - File "/usr/local/lib/python3.8/dist-packages/torch_xla/utils/utils.py", line 206, in _for_each_instance_rewrite + File "/usr/local/lib/python3.9/dist-packages/torch_xla/utils/utils.py", line 206, in _for_each_instance_rewrite result.__dict__[k] = v TypeError: 'mappingproxy' object does not support item assignment @@ -78,7 +78,7 @@ A lot of PyTorch operations aren't lowered to XLA, which could lead to significa These operations are moved to the CPU memory and evaluated, and then the results are transferred back to the XLA device(s). By using the `xla_debug` Strategy, users could create a metrics report to diagnose issues. -The report includes things like (`XLA Reference `_): +The report includes things like (`XLA Reference `_): * how many times we issue XLA compilations and time spent on issuing. * how many times we execute and time spent on execution diff --git a/docs/source-pytorch/advanced/compile.rst b/docs/source-pytorch/advanced/compile.rst index d5bd333c041b3..16fe91ca282df 100644 --- a/docs/source-pytorch/advanced/compile.rst +++ b/docs/source-pytorch/advanced/compile.rst @@ -138,6 +138,122 @@ always exclude the first call to ``forward()``/``*_step()`` from your measuremen ---- +************************************** +Apply torch.compile in configure_model +************************************** + +:func:`torch.compile` can also be invoked as part of the :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` hook. + +This is particularly handy when :func:`torch.compile` is used in combination with :class:`~lightning.pytorch.strategies.model_parallel.ModelParallelStrategy`. + +Here is an example: + +.. code-block:: python + + import lightning as L + import torch + import torch.nn as nn + import torch.nn.functional as F + from lightning.pytorch.demos import Transformer + from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy + from torch.distributed.device_mesh import DeviceMesh + from torch.distributed._composable.fsdp.fully_shard import fully_shard + + class LanguageModel(L.LightningModule): + def __init__(self, vocab_size): + super().__init__() + self.vocab_size = vocab_size + self.model = None + + def configure_model(self): + if self.model is not None: + return + + with torch.device("meta"): + model = Transformer( + vocab_size=self.vocab_size, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + for module in model.modules(): + if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)): + fully_shard(module, mesh=self.device_mesh) + + fully_shard(model, mesh=self.device_mesh) + + self.model = torch.compile(model) + + def training_step(self, batch): + input, target = batch + output = self.model(input, target) + loss = F.nll_loss(output, target.view(-1)) + self.log("train_loss", loss) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-4) + +The advantage here is that `configure_model` is called when sharding the model, +so :func:`torch.compile` is guaranteed to run on model shards and capture distributed operations. + +Also, when using other libraries like `torch ao `_ +that need to be applied in a similar fashion, it's easy to reason about the sequence of calls +needed to achieve the equivalent of `compile(distributed(quantized(model)))`: + +.. code-block:: python + + import lightning as L + import torch + import torch.nn as nn + import torch.nn.functional as F + from lightning.pytorch.demos import Transformer + from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy + from torch.distributed._composable.fsdp.fully_shard import fully_shard + from torch.distributed.device_mesh import DeviceMesh + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + class LanguageModel(L.LightningModule): + def __init__(self, vocab_size): + super().__init__() + self.vocab_size = vocab_size + self.model = None + + def configure_model(self): + if self.model is not None: + return + + with torch.device("meta"): + model = Transformer( + vocab_size=self.vocab_size, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + float8_config = Float8LinearConfig( + pad_inner_dim=True, + ) + + def module_filter_fn(mod: torch.nn.Module, fqn: str): + return fqn != "decoder" + + convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn) + + for module in model.modules(): + if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)): + fully_shard(module, mesh=self.device_mesh) + + fully_shard(model, mesh=self.device_mesh) + + self.model = torch.compile(model) + +For a full example, see our `FP8 Distributed Transformer example `_. + +---- ****************** Avoid graph breaks @@ -253,8 +369,8 @@ Limitations There are a few limitations you should be aware of when using ``torch.compile`` **in conjunction with the Trainer**: -* The Trainer currently does not reapply ``torch.compile`` over DDP/FSDP, meaning distributed operations can't benefit from speed ups at the moment. - This limitation will be lifted in the future. +* The Trainer currently does not reapply ``torch.compile`` over :class:`~lightning.pytorch.strategies.DDPStrategy` and :class:`~lightning.pytorch.strategies.FSDPStrategy`, meaning distributed operations can't benefit from speed ups at the moment. + This limitation can be avoided by using :class:`~lightning.pytorch.strategies.model_parallel.ModelParallelStrategy`, as described in `Apply torch.compile in configure_model`_ above. * In some cases, using ``self.log()`` in your LightningModule will cause compilation errors. Until addressed, you can work around these issues by applying ``torch.compile`` to the submodule(s) of your LightningModule rather than to the entire LightningModule at once. diff --git a/docs/source-pytorch/advanced/post_training_quantization.rst b/docs/source-pytorch/advanced/post_training_quantization.rst index 504a57a4191cf..f925c6ccd47b4 100644 --- a/docs/source-pytorch/advanced/post_training_quantization.rst +++ b/docs/source-pytorch/advanced/post_training_quantization.rst @@ -33,7 +33,7 @@ Installation Prerequisites ============= -Python version: 3.8, 3.9, 3.10 +Python version: 3.9, 3.10 Install Intel® Neural Compressor ================================ diff --git a/docs/source-pytorch/common/checkpointing_basic.rst b/docs/source-pytorch/common/checkpointing_basic.rst index 5c74178f0eaaa..1026e972849ef 100644 --- a/docs/source-pytorch/common/checkpointing_basic.rst +++ b/docs/source-pytorch/common/checkpointing_basic.rst @@ -20,6 +20,13 @@ PyTorch Lightning checkpoints are fully usable in plain PyTorch. ---- +.. important:: + + **Important Update: Deprecated Method** + + Starting from PyTorch Lightning v1.0.0, the `resume_from_checkpoint` argument has been deprecated. To resume training from a checkpoint, use the `ckpt_path` argument in the `fit()` method. + Please update your code accordingly to avoid potential compatibility issues. + ************************ Contents of a checkpoint ************************ @@ -197,16 +204,31 @@ You can disable checkpointing by passing: ---- + ********************* Resume training state ********************* If you don't just want to load weights, but instead restore the full training, do the following: +Correct usage: + .. code-block:: python model = LitModel() trainer = Trainer() # automatically restores model, epoch, step, LR schedulers, etc... - trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt") + trainer.fit(model, ckpt_path="path/to/your/checkpoint.ckpt") + +.. warning:: + + The argument `resume_from_checkpoint` has been deprecated in versions of PyTorch Lightning >= 1.0.0. + To resume training from a checkpoint, use the `ckpt_path` argument in the `fit()` method instead. + +Incorrect (deprecated) usage: + +.. code-block:: python + + trainer = Trainer(resume_from_checkpoint="path/to/your/checkpoint.ckpt") + trainer.fit(model) diff --git a/docs/source-pytorch/common/index.rst b/docs/source-pytorch/common/index.rst index 738e971aec532..e0492a7e747fe 100644 --- a/docs/source-pytorch/common/index.rst +++ b/docs/source-pytorch/common/index.rst @@ -23,6 +23,7 @@ ../data/data ../model/own_your_loop ../advanced/model_init + ../common/tbptt ############# @@ -202,6 +203,13 @@ How-to Guides :col_css: col-md-4 :height: 180 +.. displayitem:: + :header: Truncated Back-Propagation Through Time + :description: Efficiently step through time when training recurrent models + :button_link: ../common/tbptt.html + :col_css: col-md-4 + :height: 180 + .. raw:: html diff --git a/docs/source-pytorch/common/progress_bar.rst b/docs/source-pytorch/common/progress_bar.rst index e0c29fccdc494..106c2289e5c7b 100644 --- a/docs/source-pytorch/common/progress_bar.rst +++ b/docs/source-pytorch/common/progress_bar.rst @@ -36,6 +36,10 @@ You can update ``refresh_rate`` (rate (number of batches) at which the progress trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)]) +.. note:: + + The ``smoothing`` option has no effect when using the default implementation of :class:`~lightning.pytorch.callbacks.TQDMProgressBar`, as the progress bar is updated using the ``bar.refresh()`` method instead of ``bar.update()``. This can cause the progress bar to become desynchronized with the actual progress. To avoid this issue, you can use the ``bar.update()`` method instead, but this may require customizing the :class:`~lightning.pytorch.callbacks.TQDMProgressBar` class. + By default the training progress bar is reset (overwritten) at each new epoch. If you wish for a new progress bar to be displayed at the end of every epoch, set :paramref:`TQDMProgressBar.leave ` to ``True``. diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst new file mode 100644 index 0000000000000..063ef8c33d319 --- /dev/null +++ b/docs/source-pytorch/common/tbptt.rst @@ -0,0 +1,59 @@ +############################################## +Truncated Backpropagation Through Time (TBPTT) +############################################## + +Truncated Backpropagation Through Time (TBPTT) performs backpropogation every k steps of +a much longer sequence. This is made possible by passing training batches +split along the time-dimensions into splits of size k to the +``training_step``. In order to keep the same forward propagation behavior, all +hidden states should be kept in-between each time-dimension split. + + +.. code-block:: python + + import torch + import torch.optim as optim + import pytorch_lightning as pl + from pytorch_lightning import LightningModule + + class LitModel(LightningModule): + + def __init__(self): + super().__init__() + + # 1. Switch to manual optimization + self.automatic_optimization = False + + self.truncated_bptt_steps = 10 + self.my_rnn = ParityModuleRNN() # Define RNN model using ParityModuleRNN + + # 2. Remove the `hiddens` argument + def training_step(self, batch, batch_idx): + + # 3. Split the batch in chunks along the time dimension + split_batches = split_batch(batch, self.truncated_bptt_steps) + + batch_size = 10 + hidden_dim = 20 + hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device) + for split_batch in range(split_batches): + # 4. Perform the optimization in a loop + loss, hiddens = self.my_rnn(split_batch, hiddens) + self.backward(loss) + self.optimizer.step() + self.optimizer.zero_grad() + + # 5. "Truncate" + hiddens = hiddens.detach() + + # 6. Remove the return of `hiddens` + # Returning loss in manual optimization is not needed + return None + + def configure_optimizers(self): + return optim.Adam(self.my_rnn.parameters(), lr=0.001) + + if __name__ == "__main__": + model = LitModel() + trainer = pl.Trainer(max_epochs=5) + trainer.fit(model, train_dataloader) # Define your own dataloader diff --git a/docs/source-pytorch/community/governance.rst b/docs/source-pytorch/community/governance.rst index 36b613d32d5be..58a731d0576e3 100644 --- a/docs/source-pytorch/community/governance.rst +++ b/docs/source-pytorch/community/governance.rst @@ -18,7 +18,7 @@ Role: All final decisions related to Lightning. Maintainers ----------- -- Adrian Wälchli (`awaelchli `_) +- Luca Antiga (`lantiga `_) - Jirka Borovec (`Borda `_) - Justus Schock (`justusschock `_) @@ -32,6 +32,7 @@ Emeritus Maintainers Alumni ------ +- Adrian Wälchli (`awaelchli `_) - Carlos Mocholí (`carmocca `_) - Akihiro Nitta (`akihironitta `_) - Ananth Subramaniam (`ananthsub `_) diff --git a/docs/source-pytorch/conf.py b/docs/source-pytorch/conf.py index 107f58c5797b6..1488a3b625d04 100644 --- a/docs/source-pytorch/conf.py +++ b/docs/source-pytorch/conf.py @@ -462,7 +462,9 @@ def _load_py_module(name: str, location: str) -> ModuleType: ("py:obj", "lightning.pytorch.utilities.memory.is_out_of_cpu_memory"), ("py:func", "lightning.pytorch.utilities.rank_zero.rank_zero_only"), ("py:class", "lightning.pytorch.utilities.types.LRSchedulerConfig"), - ("py:class", "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig"), + ("py:class", "lightning.pytorch.utilities.types.LRSchedulerConfigType"), + ("py:class", "lightning.pytorch.utilities.types.OptimizerConfigType"), + ("py:class", "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfigType"), ("py:class", "lightning_habana.pytorch.plugins.precision.HPUPrecisionPlugin"), ("py:class", "lightning_habana.pytorch.strategies.HPUDDPStrategy"), ("py:class", "lightning_habana.pytorch.strategies.HPUParallelStrategy"), @@ -641,6 +643,7 @@ def package_list_from_file(file): r"starter/installation.html$", r"^../common/trainer.html#trainer-flags$", "https://deepgenerativemodels.github.io/assets/slides/cs236_lecture11.pdf", + "https://developer.habana.ai", # returns 403 error but redirects to intel.com documentation "https://www.intel.com/content/www/us/en/products/docs/processors/what-is-a-gpu.html", "https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/", # noqa: E501 "https://stackoverflow.com/questions/66640705/how-can-i-install-grpcio-on-an-apple-m1-silicon-laptop", diff --git a/docs/source-pytorch/levels/expert.rst b/docs/source-pytorch/levels/expert.rst index bb0fbf25a802c..981f47d3c829f 100644 --- a/docs/source-pytorch/levels/expert.rst +++ b/docs/source-pytorch/levels/expert.rst @@ -41,7 +41,7 @@ Customize and extend Lightning for things like custom hardware or distributed st :header: Level 24: Add a new accelerator or Strategy :description: Integrate a new accelerator or distributed strategy. :col_css: col-md-6 - :button_link: expert_level_27.html + :button_link: expert_level_24.html :height: 150 :tag: expert diff --git a/docs/source-pytorch/levels/expert_level_25.rst b/docs/source-pytorch/levels/expert_level_24.rst similarity index 100% rename from docs/source-pytorch/levels/expert_level_25.rst rename to docs/source-pytorch/levels/expert_level_24.rst diff --git a/docs/source-pytorch/upgrade/sections/2_0_regular.rst b/docs/source-pytorch/upgrade/sections/2_0_regular.rst index 192f20bc669b9..2f94ef7ab66fd 100644 --- a/docs/source-pytorch/upgrade/sections/2_0_regular.rst +++ b/docs/source-pytorch/upgrade/sections/2_0_regular.rst @@ -6,7 +6,7 @@ - Then - Ref - * - used PyTorch 3.11 + * - used PyTorch 1.11 - upgrade to PyTorch 2.1 or higher - `PR18691`_ diff --git a/docs/source-pytorch/visualize/supported_exp_managers.rst b/docs/source-pytorch/visualize/supported_exp_managers.rst index 42a0e6c9a85ed..e26514e9747c4 100644 --- a/docs/source-pytorch/visualize/supported_exp_managers.rst +++ b/docs/source-pytorch/visualize/supported_exp_managers.rst @@ -134,7 +134,7 @@ Here's the full documentation for the :class:`~lightning.pytorch.loggers.TensorB Weights and Biases ================== -To use `Weights and Biases `_ (wandb) first install the wandb package: +To use `Weights and Biases `_ (wandb) first install the wandb package: .. code-block:: bash diff --git a/examples/fabric/build_your_own_trainer/run.py b/examples/fabric/build_your_own_trainer/run.py index 01044f5d94fa8..c0c2ff28ddc41 100644 --- a/examples/fabric/build_your_own_trainer/run.py +++ b/examples/fabric/build_your_own_trainer/run.py @@ -41,7 +41,8 @@ def training_step(self, batch, batch_idx: int): def configure_optimizers(self): optim = torch.optim.Adam(self.parameters(), lr=1e-4) - return optim, { + return { + "optimizer": optim, "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode="max", verbose=True), "monitor": "val_accuracy", "interval": "epoch", diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py index 7af01ede054a8..f4f31c114f084 100644 --- a/examples/fabric/build_your_own_trainer/trainer.py +++ b/examples/fabric/build_your_own_trainer/trainer.py @@ -1,7 +1,7 @@ import os -from collections.abc import Mapping +from collections.abc import Iterable, Mapping from functools import partial -from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast +from typing import Any, Literal, Optional, Union, cast import lightning as L import torch @@ -19,11 +19,11 @@ def __init__( self, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - devices: Union[List[int], str, int] = "auto", + devices: Union[list[int], str, int] = "auto", precision: Union[str, int] = "32-true", plugins: Optional[Union[str, Any]] = None, - callbacks: Optional[Union[List[Any], Any]] = None, - loggers: Optional[Union[Logger, List[Logger]]] = None, + callbacks: Optional[Union[list[Any], Any]] = None, + loggers: Optional[Union[Logger, list[Logger]]] = None, max_epochs: Optional[int] = 1000, max_steps: Optional[int] = None, grad_accum_steps: int = 1, @@ -465,7 +465,7 @@ def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]: def _parse_optimizers_schedulers( self, configure_optim_output - ) -> Tuple[ + ) -> tuple[ Optional[L.fabric.utilities.types.Optimizable], Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]], ]: diff --git a/examples/fabric/fp8_distributed_transformer/README.md b/examples/fabric/fp8_distributed_transformer/README.md new file mode 100644 index 0000000000000..e980d759bb3ff --- /dev/null +++ b/examples/fabric/fp8_distributed_transformer/README.md @@ -0,0 +1,39 @@ +## Distributed, Low-Precision Transformer Example + +This example shows how to use `ModelParallelStrategy` in `Fabric` to train a Transformer model minimizing memory usage, maximizing throughput, and distributing load across multiple GPUs. + +### Training Large Models and Memory Requirements + +One of the main challenges when training large models, like large language models (LLMs), is dealing with their memory footprint. LLMs can be so large that weights, activations, gradients and optimizer state don't fit a single GPU, so that they need to be distributed across multiple GPUs, and across multiple machines. There are multiple ways of distributing computations, among which fully-sharded data parallelism (FSDP) and tensor parallelism (TP). + +An additional way of reducing memory requirements is representing floating point numbers in weights and activations in low numerical precision, such as 16-bit (`bfloat16`), or 8-bit (`fp8`). This leads to savings in memory usage, as well as memory bandwidth usage (fewer bytes transferred from device memory to GPU cores in unit time). + +Roughly, reducing precision to `fp8` for linear layers can lead to 2x reduction in memory requirements and 1.6x improvement in throughput. Support for `fp8` weights and activations requires recent GPUs - Hopper, Ada Lovelace and above (e.g. H100, L4, L40). + +The introduction of tensor subclasses in PyTorch brought two new APIs that can be used to achieve memory savings and distributed training (as well as inference) in combination: + +- [torch ao](https://github.com/pytorch/ao) to execute linear layers in low numerical precision (`fp8` and other quantized formats) +- [dtensors](https://pytorch.org/docs/stable/distributed.tensor.html) to distribute models across GPUs, by combining TP and FSDP (referred to FSDP2 in PyTorch) + +Notably, `torch ao` introduces quantization and dequantization operations in the model that may result in slow-downs if not optimized. Using `torch.compile` after `torch ao` recovers performance by generating optimized kernels for those operations. + +### Vanilla Transformer Example + +This example shows how to train a vanilla Transformer model using `fp8` precision and the FSDP2 distributed strategy, and then optimize the resulting model through `torch.compile`. + +Specifically, we employ the `ModelParallelStrategy`, and use the `configure_model` hook to distribute the model using the PyTorch DTensor API. +In the same hook we also pass the model through the `torch ao` API (prior to FSDP2), as well as `torch.compile` (after FSDP2). + +The resulting code follows the PyTorch API closely, while also taking advantage of the rest of PyTorch Lightning. + +To execute the code directly just run: + +```bash +python train.py +``` + +### A Note on torch.compile + +Note that PyTorch Lightning also supports calling `torch.compile` on a `LightningModule` and passing it to the `Trainer`. + +While this works for simple cases, in order to get the most out of the combination of the latest distributed, quantization, and compile PyTorch API's, we recommend invoking `torch.compile` at the end of the `configure_model` hook, as shown in this example. diff --git a/examples/fabric/fp8_distributed_transformer/requirements.txt b/examples/fabric/fp8_distributed_transformer/requirements.txt new file mode 100644 index 0000000000000..ce00e191aa9c1 --- /dev/null +++ b/examples/fabric/fp8_distributed_transformer/requirements.txt @@ -0,0 +1 @@ +torchao>=0.7.0 diff --git a/examples/fabric/fp8_distributed_transformer/train.py b/examples/fabric/fp8_distributed_transformer/train.py new file mode 100644 index 0000000000000..ba88603268945 --- /dev/null +++ b/examples/fabric/fp8_distributed_transformer/train.py @@ -0,0 +1,100 @@ +import lightning as L +import torch +import torch.nn as nn +import torch.nn.functional as F +from lightning.fabric.strategies import ModelParallelStrategy +from lightning.pytorch.demos import Transformer, WikiText2 +from torch.distributed._composable.fsdp.fully_shard import fully_shard +from torch.distributed.device_mesh import DeviceMesh +from torch.utils.data import DataLoader +from torchao.float8 import Float8LinearConfig, convert_to_float8_training +from tqdm import tqdm + + +def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + float8_config = Float8LinearConfig( + # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly # noqa + pad_inner_dim=True, + ) + + def module_filter_fn(mod: torch.nn.Module, fqn: str): + # we skip the decoder because it typically vocabulary size + # is not divisible by 16 as required by float8 + return fqn != "decoder" + + convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn) + + for module in model.modules(): + if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)): + fully_shard(module, mesh=device_mesh) + + fully_shard(model, mesh=device_mesh) + + return torch.compile(model) + + +def train(): + L.seed_everything(42) + + batch_size = 8 + micro_batch_size = 1 + + max_steps = 100 + + dataset = WikiText2() + dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size) + + with torch.device("meta"): + model = Transformer( + vocab_size=dataset.vocab_size, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=configure_model) + + fabric = L.Fabric(precision="bf16-true", strategy=strategy) + fabric.launch() + + model = fabric.setup(model) + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + optimizer = fabric.setup_optimizers(optimizer) + + dataloader = fabric.setup_dataloaders(dataloader) + + iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader) + + steps = 0 + + for i, batch in iterable: + input, target = batch + + is_accumulating = i % (batch_size // micro_batch_size) != 0 + + with fabric.no_backward_sync(model, enabled=is_accumulating): + output = model(input, target) + loss = F.nll_loss(output, target.view(-1)) + fabric.backward(loss) + + if not is_accumulating: + fabric.clip_gradients(model, optimizer, max_norm=1.0) + optimizer.step() + optimizer.zero_grad() + steps += 1 + + if fabric.is_global_zero: + iterable.set_postfix_str(f"train_loss={loss.item():.2f}") + + if steps == max_steps: + break + + fabric.print(torch.cuda.memory_summary()) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + train() diff --git a/examples/fabric/reinforcement_learning/rl/agent.py b/examples/fabric/reinforcement_learning/rl/agent.py index fcf5bd0b9371d..16a4cd6d86c73 100644 --- a/examples/fabric/reinforcement_learning/rl/agent.py +++ b/examples/fabric/reinforcement_learning/rl/agent.py @@ -1,5 +1,4 @@ import math -from typing import Dict, Tuple import gymnasium as gym import torch @@ -43,7 +42,7 @@ def __init__(self, envs: gym.vector.SyncVectorEnv, act_fun: str = "relu", ortho_ layer_init(torch.nn.Linear(64, envs.single_action_space.n), std=0.01, ortho_init=ortho_init), ) - def get_action(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor]: + def get_action(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor]: logits = self.actor(x) distribution = Categorical(logits=logits) if action is None: @@ -58,12 +57,12 @@ def get_greedy_action(self, x: Tensor) -> Tensor: def get_value(self, x: Tensor) -> Tensor: return self.critic(x) - def get_action_and_value(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def get_action_and_value(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]: action, log_prob, entropy = self.get_action(x, action) value = self.get_value(x) return action, log_prob, entropy, value - def forward(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def forward(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]: return self.get_action_and_value(x, action) @torch.no_grad() @@ -77,7 +76,7 @@ def estimate_returns_and_advantages( num_steps: int, gamma: float, gae_lambda: float, - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: next_value = self.get_value(next_obs).reshape(1, -1) advantages = torch.zeros_like(rewards) lastgaelam = 0 @@ -143,7 +142,7 @@ def __init__( self.avg_value_loss = MeanMetric(**torchmetrics_kwargs) self.avg_ent_loss = MeanMetric(**torchmetrics_kwargs) - def get_action(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor]: + def get_action(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor]: logits = self.actor(x) distribution = Categorical(logits=logits) if action is None: @@ -158,12 +157,12 @@ def get_greedy_action(self, x: Tensor) -> Tensor: def get_value(self, x: Tensor) -> Tensor: return self.critic(x) - def get_action_and_value(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def get_action_and_value(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]: action, log_prob, entropy = self.get_action(x, action) value = self.get_value(x) return action, log_prob, entropy, value - def forward(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def forward(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]: return self.get_action_and_value(x, action) @torch.no_grad() @@ -177,7 +176,7 @@ def estimate_returns_and_advantages( num_steps: int, gamma: float, gae_lambda: float, - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: next_value = self.get_value(next_obs).reshape(1, -1) advantages = torch.zeros_like(rewards) lastgaelam = 0 @@ -193,7 +192,7 @@ def estimate_returns_and_advantages( returns = advantages + values return returns, advantages - def training_step(self, batch: Dict[str, Tensor]): + def training_step(self, batch: dict[str, Tensor]): # Get actions and values given the current observations _, newlogprob, entropy, newvalue = self(batch["obs"], batch["actions"].long()) logratio = newlogprob - batch["logprobs"] diff --git a/examples/fabric/reinforcement_learning/rl/utils.py b/examples/fabric/reinforcement_learning/rl/utils.py index 7585eda616ee4..4c5a8066b359f 100644 --- a/examples/fabric/reinforcement_learning/rl/utils.py +++ b/examples/fabric/reinforcement_learning/rl/utils.py @@ -1,7 +1,6 @@ import argparse import math import os -from distutils.util import strtobool from typing import TYPE_CHECKING, Optional, Union import gymnasium as gym @@ -12,6 +11,23 @@ from rl.agent import PPOAgent, PPOLightningAgent +def strtobool(val): + """Convert a string representation of truth to true (1) or false (0). + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'. + Raises ValueError if 'val' is anything else. + + Note: taken from distutils after its deprecation. + + """ + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return 1 + if val in ("n", "no", "f", "false", "off", "0"): + return 0 + raise ValueError(f"invalid truth value {val!r}") + + def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--exp-name", type=str, default="default", help="the name of this experiment") diff --git a/examples/fabric/reinforcement_learning/train_fabric.py b/examples/fabric/reinforcement_learning/train_fabric.py index 068359602a096..4df52d7cd0455 100644 --- a/examples/fabric/reinforcement_learning/train_fabric.py +++ b/examples/fabric/reinforcement_learning/train_fabric.py @@ -21,7 +21,6 @@ import os import time from datetime import datetime -from typing import Dict import gymnasium as gym import torch @@ -38,7 +37,7 @@ def train( fabric: Fabric, agent: PPOLightningAgent, optimizer: torch.optim.Optimizer, - data: Dict[str, Tensor], + data: dict[str, Tensor], global_step: int, args: argparse.Namespace, ): diff --git a/examples/fabric/reinforcement_learning/train_torch.py b/examples/fabric/reinforcement_learning/train_torch.py index cf74e03f5202e..dad16ad10fb0b 100644 --- a/examples/fabric/reinforcement_learning/train_torch.py +++ b/examples/fabric/reinforcement_learning/train_torch.py @@ -22,7 +22,6 @@ import random import time from datetime import datetime -from typing import Dict import gymnasium as gym import torch @@ -41,7 +40,7 @@ def train( agent: PPOAgent, optimizer: torch.optim.Optimizer, - data: Dict[str, Tensor], + data: dict[str, Tensor], logger: SummaryWriter, global_step: int, args: argparse.Namespace, diff --git a/examples/fabric/tensor_parallel/model.py b/examples/fabric/tensor_parallel/model.py index 3c9e7de472b90..71f2634867e9b 100644 --- a/examples/fabric/tensor_parallel/model.py +++ b/examples/fabric/tensor_parallel/model.py @@ -9,7 +9,7 @@ from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn.functional as F @@ -87,7 +87,7 @@ def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided diff --git a/examples/pytorch/basics/autoencoder.py b/examples/pytorch/basics/autoencoder.py index 260cce4a548b1..d6f594b12f57b 100644 --- a/examples/pytorch/basics/autoencoder.py +++ b/examples/pytorch/basics/autoencoder.py @@ -18,7 +18,7 @@ """ from os import path -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn.functional as F @@ -45,7 +45,7 @@ def __init__( nrow: int = 8, padding: int = 2, normalize: bool = True, - value_range: Optional[Tuple[int, int]] = None, + value_range: Optional[tuple[int, int]] = None, scale_each: bool = False, pad_value: int = 0, ) -> None: diff --git a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py index 497cb658c275f..b3bfaaea93e7f 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py +++ b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py @@ -35,7 +35,7 @@ import argparse import random from collections import OrderedDict, deque, namedtuple -from typing import Iterator, List, Tuple +from collections.abc import Iterator import gym import torch @@ -102,7 +102,7 @@ def append(self, experience: Experience) -> None: """ self.buffer.append(experience) - def sample(self, batch_size: int) -> Tuple: + def sample(self, batch_size: int) -> tuple: indices = random.sample(range(len(self.buffer)), batch_size) states, actions, rewards, dones, next_states = zip(*(self.buffer[idx] for idx in indices)) @@ -190,7 +190,7 @@ def get_action(self, net: nn.Module, epsilon: float, device: str) -> int: return action @torch.no_grad() - def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = "cpu") -> Tuple[float, bool]: + def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = "cpu") -> tuple[float, bool]: """Carries out a single interaction step between the agent and the environment. Args: @@ -295,7 +295,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ return self.net(x) - def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + def dqn_mse_loss(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: """Calculates the mse loss using a mini batch from the replay buffer. Args: @@ -318,7 +318,7 @@ def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor return nn.MSELoss()(state_action_values, expected_state_action_values) - def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict: + def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict: """Carries out a single step through the environment to update the replay buffer. Then calculates loss based on the minibatch received. @@ -356,7 +356,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O return OrderedDict({"loss": loss, "log": log, "progress_bar": log}) - def configure_optimizers(self) -> List[Optimizer]: + def configure_optimizers(self) -> list[Optimizer]: """Initialize Adam optimizer.""" optimizer = optim.Adam(self.net.parameters(), lr=self.lr) return [optimizer] diff --git a/examples/pytorch/domain_templates/reinforce_learn_ppo.py b/examples/pytorch/domain_templates/reinforce_learn_ppo.py index bc3f8c1b9b193..1fb083894c284 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_ppo.py +++ b/examples/pytorch/domain_templates/reinforce_learn_ppo.py @@ -30,7 +30,8 @@ """ import argparse -from typing import Callable, Iterator, List, Tuple +from collections.abc import Iterator +from typing import Callable import gym import torch @@ -41,7 +42,7 @@ from torch.utils.data import DataLoader, IterableDataset -def create_mlp(input_shape: Tuple[int], n_actions: int, hidden_size: int = 128): +def create_mlp(input_shape: tuple[int], n_actions: int, hidden_size: int = 128): """Simple Multi-Layer Perceptron network.""" return nn.Sequential( nn.Linear(input_shape[0], hidden_size), @@ -227,7 +228,7 @@ def __init__( self.state = torch.FloatTensor(self.env.reset()) - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Passes in a state x through the network and returns the policy and a sampled action. Args: @@ -242,7 +243,7 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Te return pi, action, value - def discount_rewards(self, rewards: List[float], discount: float) -> List[float]: + def discount_rewards(self, rewards: list[float], discount: float) -> list[float]: """Calculate the discounted rewards of all rewards in list. Args: @@ -263,7 +264,7 @@ def discount_rewards(self, rewards: List[float], discount: float) -> List[float] return list(reversed(cumul_reward)) - def calc_advantage(self, rewards: List[float], values: List[float], last_value: float) -> List[float]: + def calc_advantage(self, rewards: list[float], values: list[float], last_value: float) -> list[float]: """Calculate the advantage given rewards, state values, and the last value of episode. Args: @@ -281,7 +282,7 @@ def calc_advantage(self, rewards: List[float], values: List[float], last_value: delta = [rews[i] + self.gamma * vals[i + 1] - vals[i] for i in range(len(rews) - 1)] return self.discount_rewards(delta, self.gamma * self.lam) - def generate_trajectory_samples(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + def generate_trajectory_samples(self) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: """ Contains the logic for generating trajectory data to train policy and value network Yield: @@ -375,7 +376,7 @@ def critic_loss(self, state, action, logp_old, qval, adv) -> torch.Tensor: value = self.critic(state) return (qval - value).pow(2).mean() - def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor]): + def training_step(self, batch: tuple[torch.Tensor, torch.Tensor]): """Carries out a single update to actor and critic network from a batch of replay buffer. Args: @@ -406,7 +407,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor]): self.log("loss_critic", loss_critic, on_step=False, on_epoch=True, prog_bar=False, logger=True) self.log("loss_actor", loss_actor, on_step=False, on_epoch=True, prog_bar=True, logger=True) - def configure_optimizers(self) -> List[Optimizer]: + def configure_optimizers(self) -> list[Optimizer]: """Initialize Adam optimizer.""" optimizer_actor = torch.optim.Adam(self.actor.parameters(), lr=self.lr_actor) optimizer_critic = torch.optim.Adam(self.critic.parameters(), lr=self.lr_critic) diff --git a/examples/pytorch/fp8_distributed_transformer/README.md b/examples/pytorch/fp8_distributed_transformer/README.md new file mode 100644 index 0000000000000..6c5e12d14da4a --- /dev/null +++ b/examples/pytorch/fp8_distributed_transformer/README.md @@ -0,0 +1,39 @@ +## Distributed, Low-Precision Transformer Example + +This example shows how to use `ModelParallelStrategy` in `Fabric` to train a Transformer model minimizing memory usage, maximizing throughput, and distributing load across multiple GPUs. + +### Training Large Models and Memory Requirements + +One of the main challenges when training large models, like large language models (LLMs), is dealing with their memory footprint. LLMs can be so large that weights, activations, gradients and optimizer state don't fit a single GPU, so that they need to be distributed across multiple GPUs, and across multiple machines. There are multiple ways of distributing computations, among which fully-sharded data parallelism (FSDP) and tensor parallelism (TP). + +An additional way of reducing memory requirements is representing floating point numbers in weights and activations in low numerical precision, such as 16-bit (`bfloat16`), or 8-bit (`fp8`). This leads to savings in memory usage, as well as memory bandwidth usage (fewer bytes transferred from device memory to GPU cores in unit time). + +Roughly, reducing precision to `fp8` for linear layers can lead to 2x reduction in memory requirements and 1.6x improvement in throughput. Support for `fp8` weights and activations requires recent GPUs - Hopper, Ada Lovelace and above (e.g. H100, L4, L40). + +The introduction of tensor subclasses in PyTorch brought two new APIs that can be used to achieve memory savings and distributed training (as well as inference) in combination: + +- [torch ao](https://github.com/pytorch/ao) to execute linear layers in low numerical precision (`fp8` and other quantized formats) +- [dtensors](https://pytorch.org/docs/stable/distributed.tensor.html) to distribute models across GPUs, by combining TP and FSDP (referred to FSDP2 in PyTorch) + +Notably, `torch ao` introduces quantization and dequantization operations in the model that may result in slow-downs if not optimized. Using `torch.compile` after `torch ao` recovers performance by generating optimized kernels for those operations. + +### Vanilla Transformer Example + +This example shows how to train a vanilla Transformer model using `fp8` precision and the FSDP2 distributed strategy, and then optimize the resulting model through `torch.compile`. + +Specifically, we employ the `ModelParallelStrategy`, which accepts a `parallelize_fn` to distribute the model using the PyTorch DTensor API. +We use the same function to also pass the model through the `torch ao` API (prior to FSDP2), as well as `torch.compile` (after FSDP2). + +The resulting code follows the PyTorch API closely, while also taking advantage of the rest of Lightning Fabric. + +To execute the code directly just run: + +```bash +python train.py +``` + +### A Note on torch.compile + +Note that Fabric also supports calling `torch.compile` on a model and passing it to `fabric.setup_model` or `fabric.setup_model_and_optimizers`. + +While this works well, in order to get the most out of the combination of the latest distributed, quantization, and compile PyTorch API's, we recommend invoking `torch.compile` as part of the `parallelize_fn` argument of `ModelParallelStrategy`, as shown in this example. diff --git a/examples/pytorch/fp8_distributed_transformer/requirements.txt b/examples/pytorch/fp8_distributed_transformer/requirements.txt new file mode 100644 index 0000000000000..ce00e191aa9c1 --- /dev/null +++ b/examples/pytorch/fp8_distributed_transformer/requirements.txt @@ -0,0 +1 @@ +torchao>=0.7.0 diff --git a/examples/pytorch/fp8_distributed_transformer/train.py b/examples/pytorch/fp8_distributed_transformer/train.py new file mode 100644 index 0000000000000..6c7be98ee7dbd --- /dev/null +++ b/examples/pytorch/fp8_distributed_transformer/train.py @@ -0,0 +1,85 @@ +import lightning as L +import torch +import torch.nn as nn +import torch.nn.functional as F +from lightning.pytorch.demos import Transformer, WikiText2 +from lightning.pytorch.strategies import ModelParallelStrategy +from torch.distributed._composable.fsdp.fully_shard import fully_shard +from torch.utils.data import DataLoader +from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + +class LanguageModel(L.LightningModule): + def __init__(self, vocab_size): + super().__init__() + self.vocab_size = vocab_size + self.model = None + + def configure_model(self): + if self.model is not None: + return + + with torch.device("meta"): + model = Transformer( + vocab_size=self.vocab_size, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + float8_config = Float8LinearConfig( + # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly # noqa + pad_inner_dim=True, + ) + + def module_filter_fn(mod: torch.nn.Module, fqn: str): + # we skip the decoder because it typically vocabulary size + # is not divisible by 16 as required by float8 + return fqn != "decoder" + + convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn) + + for module in model.modules(): + if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)): + fully_shard(module, mesh=self.device_mesh) + + fully_shard(model, mesh=self.device_mesh) + + self.model = torch.compile(model) + + def training_step(self, batch): + input, target = batch + output = self.model(input, target) + loss = F.nll_loss(output, target.view(-1)) + self.log("train_loss", loss, prog_bar=True) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-4) + + +def train(): + L.seed_everything(42) + + dataset = WikiText2() + train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1) + + model = LanguageModel(vocab_size=dataset.vocab_size) + + mp_strategy = ModelParallelStrategy( + data_parallel_size=4, + tensor_parallel_size=1, + ) + + trainer = L.Trainer(strategy=mp_strategy, max_steps=100, precision="bf16-true", accumulate_grad_batches=8) + + trainer.fit(model, train_dataloader) + + trainer.print(torch.cuda.memory_summary()) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + train() diff --git a/examples/pytorch/servable_module/production.py b/examples/pytorch/servable_module/production.py index f1d5a06e3c584..da0c42d12a865 100644 --- a/examples/pytorch/servable_module/production.py +++ b/examples/pytorch/servable_module/production.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from io import BytesIO from os import path -from typing import Dict, Optional +from typing import Optional import numpy as np import torch @@ -93,7 +93,7 @@ def configure_payload(self): def configure_serialization(self): return {"x": Image(224, 224).deserialize}, {"output": Top1().serialize} - def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: + def serve_step(self, x: torch.Tensor) -> dict[str, torch.Tensor]: return {"output": self.model(x)} def configure_response(self): diff --git a/examples/pytorch/tensor_parallel/model.py b/examples/pytorch/tensor_parallel/model.py index 3c9e7de472b90..71f2634867e9b 100644 --- a/examples/pytorch/tensor_parallel/model.py +++ b/examples/pytorch/tensor_parallel/model.py @@ -9,7 +9,7 @@ from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn.functional as F @@ -87,7 +87,7 @@ def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided diff --git a/pyproject.toml b/pyproject.toml index 6edd6d1a8f11f..48439bee75332 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ ignore-words-list = "te, compiletime" [tool.ruff] line-length = 120 -target-version = "py38" +target-version = "py39" # Exclude a variety of commonly ignored directories. exclude = [ ".git", @@ -76,7 +76,6 @@ ignore = [ "S108", "E203", # conflicts with black ] -ignore-init-module-imports = true [tool.ruff.lint.per-file-ignores] ".actions/*" = ["S101", "S310"] diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index 0a99614a46870..42c055e85ca7d 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -1,7 +1,7 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torch >=2.1.0, <2.5.0 +torch >=2.1.0, <2.6.0 fsspec[http] >=2022.5.0, <2024.4.0 packaging >=20.0, <=23.1 typing-extensions >=4.4.0, <4.10.0 diff --git a/requirements/fabric/examples.txt b/requirements/fabric/examples.txt index cb4135da2409a..3352db77d8bd9 100644 --- a/requirements/fabric/examples.txt +++ b/requirements/fabric/examples.txt @@ -1,6 +1,6 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torchvision >=0.16.0, <0.20.0 -torchmetrics >=0.10.0, <1.3.0 +torchvision >=0.16.0, <0.21.0 +torchmetrics >=0.10.0, <1.5.0 lightning-utilities >=0.8.0, <0.12.0 diff --git a/requirements/fabric/strategies.txt b/requirements/fabric/strategies.txt index 4aee89d9f68e7..394aceb39cd6b 100644 --- a/requirements/fabric/strategies.txt +++ b/requirements/fabric/strategies.txt @@ -6,4 +6,5 @@ # note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods` # shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372 deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict -bitsandbytes >=0.42.0,<0.43.0 +bitsandbytes >=0.44.0,<0.44.2; sys_platform == 'linux' or sys_platform == 'win32' +bitsandbytes >=0.42.0,<0.43.0 ; sys_platform == 'darwin' diff --git a/requirements/fabric/test.txt b/requirements/fabric/test.txt index 8fb9122051eec..2da6ae8854d64 100644 --- a/requirements/fabric/test.txt +++ b/requirements/fabric/test.txt @@ -7,4 +7,4 @@ pytest-rerunfailures ==12.0 pytest-random-order ==1.1.0 click ==8.1.7 tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute -torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version +torchmetrics >=0.7.0, <1.5.0 # needed for using fixed compare_version diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index 6ff628d7edfb5..94aca759c37e2 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -1,11 +1,11 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torch >=2.1.0, <2.5.0 +torch >=2.1.0, <2.6.0 tqdm >=4.57.0, <4.67.0 PyYAML >=5.4, <6.1.0 fsspec[http] >=2022.5.0, <2024.4.0 -torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version +torchmetrics >=0.7.0, <1.5.0 # needed for using fixed compare_version packaging >=20.0, <=23.1 typing-extensions >=4.4.0, <4.10.0 lightning-utilities >=0.10.0, <0.12.0 diff --git a/requirements/pytorch/examples.txt b/requirements/pytorch/examples.txt index 9a6ae7e47dfb8..2e793e0045da9 100644 --- a/requirements/pytorch/examples.txt +++ b/requirements/pytorch/examples.txt @@ -2,7 +2,7 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment requests <2.32.0 -torchvision >=0.16.0, <0.20.0 +torchvision >=0.16.0, <0.21.0 ipython[all] <8.15.0 -torchmetrics >=0.10.0, <1.3.0 +torchmetrics >=0.10.0, <1.5.0 lightning-utilities >=0.8.0, <0.12.0 diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index 6962da858c4ab..12bbdf5a70ab0 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -8,4 +8,5 @@ hydra-core >=1.2.0, <1.4.0 jsonargparse[signatures] >=4.27.7, <4.28.0 rich >=12.3.0, <13.6.0 tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute -bitsandbytes >=0.42.0,<0.43.0 +bitsandbytes >=0.44.0,<0.44.2; sys_platform == 'linux' or sys_platform == 'win32' +bitsandbytes >=0.42.0,<0.43.0 ; sys_platform == 'darwin' diff --git a/requirements/typing.txt b/requirements/typing.txt index 9f1952605babc..71414998dd7f3 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,5 +1,5 @@ mypy==1.11.0 -torch==2.4.0 +torch==2.5.1 types-Markdown types-PyYAML diff --git a/setup.py b/setup.py index bfc329bb8fe88..92f0265eafb9f 100755 --- a/setup.py +++ b/setup.py @@ -45,9 +45,10 @@ import logging import os import tempfile +from collections.abc import Generator, Mapping from importlib.util import module_from_spec, spec_from_file_location from types import ModuleType -from typing import Generator, Mapping, Optional +from typing import Optional import setuptools import setuptools.command.egg_info diff --git a/src/lightning/__setup__.py b/src/lightning/__setup__.py index 09eab5601f443..2d3bb0e7d1f33 100644 --- a/src/lightning/__setup__.py +++ b/src/lightning/__setup__.py @@ -3,7 +3,7 @@ from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from types import ModuleType -from typing import Any, Dict +from typing import Any from setuptools import find_namespace_packages @@ -26,7 +26,7 @@ def _load_py_module(name: str, location: str) -> ModuleType: _ASSISTANT = _load_py_module(name="assistant", location=os.path.join(_PROJECT_ROOT, ".actions", "assistant.py")) -def _prepare_extras() -> Dict[str, Any]: +def _prepare_extras() -> dict[str, Any]: # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. # From remote, use like `pip install "lightning[dev, docs]"` @@ -63,7 +63,7 @@ def _prepare_extras() -> Dict[str, Any]: return extras -def _setup_args() -> Dict[str, Any]: +def _setup_args() -> dict[str, Any]: about = _load_py_module("about", os.path.join(_PACKAGE_ROOT, "__about__.py")) version = _load_py_module("version", os.path.join(_PACKAGE_ROOT, "__version__.py")) long_description = _ASSISTANT.load_readme_description( diff --git a/src/lightning/fabric/__init__.py b/src/lightning/fabric/__init__.py index 921d3d61e60fe..d675b21e5d1d2 100644 --- a/src/lightning/fabric/__init__.py +++ b/src/lightning/fabric/__init__.py @@ -2,6 +2,7 @@ import logging import os +import sys from lightning_utilities.core.imports import package_available @@ -26,6 +27,10 @@ # https://github.com/pytorch/pytorch/issues/83973 os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1" +# see https://github.com/pytorch/pytorch/issues/139990 +if sys.platform == "win32": + os.environ["USE_LIBUV"] = "0" + from lightning.fabric.fabric import Fabric # noqa: E402 from lightning.fabric.utilities.seed import seed_everything # noqa: E402 diff --git a/src/lightning/fabric/accelerators/cpu.py b/src/lightning/fabric/accelerators/cpu.py index 1bcec1b2ac278..2997d1ada3352 100644 --- a/src/lightning/fabric/accelerators/cpu.py +++ b/src/lightning/fabric/accelerators/cpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union import torch from typing_extensions import override @@ -39,13 +39,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> int: + def parse_devices(devices: Union[int, str]) -> int: """Accelerator device parsing logic.""" return _parse_cpu_cores(devices) @staticmethod @override - def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices @@ -72,12 +72,12 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No ) -def _parse_cpu_cores(cpu_cores: Union[int, str, List[int]]) -> int: +def _parse_cpu_cores(cpu_cores: Union[int, str]) -> int: """Parses the cpu_cores given in the format as accepted by the ``devices`` argument in the :class:`~lightning.pytorch.trainer.trainer.Trainer`. Args: - cpu_cores: An int > 0. + cpu_cores: An int > 0 or a string that can be converted to an int > 0. Returns: An int representing the number of processes diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py index 4afc9be723fc2..5b8a4c2f80bed 100644 --- a/src/lightning/fabric/accelerators/cuda.py +++ b/src/lightning/fabric/accelerators/cuda.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import lru_cache -from typing import List, Optional, Union +from typing import Optional, Union import torch from typing_extensions import override @@ -43,7 +43,7 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: """Accelerator device parsing logic.""" from lightning.fabric.utilities.device_parser import _parse_gpu_ids @@ -51,7 +51,7 @@ def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: @staticmethod @override - def get_parallel_devices(devices: List[int]) -> List[torch.device]: + def get_parallel_devices(devices: list[int]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" return [torch.device("cuda", i) for i in devices] @@ -76,7 +76,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No ) -def find_usable_cuda_devices(num_devices: int = -1) -> List[int]: +def find_usable_cuda_devices(num_devices: int = -1) -> list[int]: """Returns a list of all available and usable CUDA GPU devices. A GPU is considered usable if we can successfully move a tensor to the device, and this is what this function @@ -129,7 +129,7 @@ def find_usable_cuda_devices(num_devices: int = -1) -> List[int]: return available_devices -def _get_all_visible_cuda_devices() -> List[int]: +def _get_all_visible_cuda_devices() -> list[int]: """Returns a list of all visible CUDA GPU devices. Devices masked by the environment variabale ``CUDA_VISIBLE_DEVICES`` won't be returned here. For example, assume you diff --git a/src/lightning/fabric/accelerators/mps.py b/src/lightning/fabric/accelerators/mps.py index 75497169cda0f..b535ba57ed4cb 100644 --- a/src/lightning/fabric/accelerators/mps.py +++ b/src/lightning/fabric/accelerators/mps.py @@ -14,7 +14,7 @@ import os import platform from functools import lru_cache -from typing import List, Optional, Union +from typing import Optional, Union import torch from typing_extensions import override @@ -46,7 +46,7 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: """Accelerator device parsing logic.""" from lightning.fabric.utilities.device_parser import _parse_gpu_ids @@ -54,7 +54,7 @@ def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: @staticmethod @override - def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, str, list[int]]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" parsed_devices = MPSAccelerator.parse_devices(devices) assert parsed_devices is not None @@ -84,7 +84,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No ) -def _get_all_available_mps_gpus() -> List[int]: +def _get_all_available_mps_gpus() -> list[int]: """ Returns: A list of all available MPS GPUs diff --git a/src/lightning/fabric/accelerators/registry.py b/src/lightning/fabric/accelerators/registry.py index 1299b1e148aa8..17d5233336d50 100644 --- a/src/lightning/fabric/accelerators/registry.py +++ b/src/lightning/fabric/accelerators/registry.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional from typing_extensions import override @@ -68,7 +68,7 @@ def register( if name in self and not override: raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.") - data: Dict[str, Any] = {} + data: dict[str, Any] = {} data["description"] = description data["init_params"] = init_params @@ -107,7 +107,7 @@ def remove(self, name: str) -> None: """Removes the registered accelerator by name.""" self.pop(name) - def available_accelerators(self) -> List[str]: + def available_accelerators(self) -> list[str]: """Returns a list of registered accelerators.""" return list(self.keys()) diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py index 38d7380dc7905..d438197329939 100644 --- a/src/lightning/fabric/accelerators/xla.py +++ b/src/lightning/fabric/accelerators/xla.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, List, Union +from typing import Any, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -47,13 +47,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]: + def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]: """Accelerator device parsing logic.""" return _parse_tpu_devices(devices) @staticmethod @override - def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_tpu_devices(devices) if isinstance(devices, int): @@ -102,20 +102,27 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No # PJRT support requires this minimum version _XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla") _XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1") +_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5") def _using_pjrt() -> bool: + # `using_pjrt` is removed in torch_xla 2.5 + if _XLA_GREATER_EQUAL_2_5: + from torch_xla import runtime as xr + + return xr.device_type() is not None # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped. if _XLA_GREATER_EQUAL_2_1: from torch_xla import runtime as xr return xr.using_pjrt() + from torch_xla.experimental import pjrt return pjrt.using_pjrt() -def _parse_tpu_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]: +def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]: """Parses the TPU devices given in the format as accepted by the :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`. @@ -152,7 +159,7 @@ def _check_tpu_devices_valid(devices: object) -> None: ) -def _parse_tpu_devices_str(devices: str) -> Union[int, List[int]]: +def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]: devices = devices.strip() try: return int(devices) diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 5ca46ba331622..5f18884e83d79 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -17,7 +17,7 @@ import subprocess import sys from argparse import Namespace -from typing import Any, List, Optional +from typing import Any, Optional import torch from lightning_utilities.core.imports import RequirementCache @@ -39,7 +39,7 @@ _SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") -def _get_supported_strategies() -> List[str]: +def _get_supported_strategies() -> list[str]: """Returns strategy choices from the registry, with the ones removed that are incompatible to be launched from the CLI or ones that require further configuration by the user.""" available_strategies = STRATEGY_REGISTRY.available_strategies() @@ -140,7 +140,7 @@ def _main() -> None: type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)), default=None, help=( - "Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``64``), " + "Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``32``), " "half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)" ), ) @@ -221,7 +221,7 @@ def _get_num_processes(accelerator: str, devices: str) -> int: return len(parsed_devices) if parsed_devices is not None else 0 -def _torchrun_launch(args: Namespace, script_args: List[str]) -> None: +def _torchrun_launch(args: Namespace, script_args: list[str]) -> None: """This will invoke `torchrun` programmatically to launch the given script in new processes.""" import torch.distributed.run as torchrun @@ -242,7 +242,7 @@ def _torchrun_launch(args: Namespace, script_args: List[str]) -> None: torchrun.main(torchrun_args) -def main(args: Namespace, script_args: Optional[List[str]] = None) -> None: +def main(args: Namespace, script_args: Optional[list[str]] = None) -> None: _set_env_variables(args) _torchrun_launch(args, script_args or []) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 9fb66255830c6..0ade7f69c3629 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -13,7 +13,8 @@ # limitations under the License. import os from collections import Counter -from typing import Any, Dict, List, Optional, Union, cast +from collections.abc import Iterable +from typing import Any, Optional, Union, cast import torch from typing_extensions import get_args @@ -99,10 +100,10 @@ def __init__( self, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - devices: Union[List[int], str, int] = "auto", + devices: Union[list[int], str, int] = "auto", num_nodes: int = 1, precision: Optional[_PRECISION_INPUT] = None, - plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]] = None, ) -> None: # These arguments can be set through environment variables set by the CLI accelerator = self._argument_from_env("accelerator", accelerator, default="auto") @@ -124,7 +125,7 @@ def __init__( self._precision_input: _PRECISION_INPUT_STR = "32-true" self._precision_instance: Optional[Precision] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None - self._parallel_devices: List[Union[int, torch.device, str]] = [] + self._parallel_devices: list[Union[int, torch.device, str]] = [] self.checkpoint_io: Optional[CheckpointIO] = None self._check_config_and_set_final_flags( @@ -165,7 +166,7 @@ def _check_config_and_set_final_flags( strategy: Union[str, Strategy], accelerator: Union[str, Accelerator], precision: Optional[_PRECISION_INPUT], - plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]], + plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]], ) -> None: """This method checks: @@ -180,7 +181,7 @@ def _check_config_and_set_final_flags( """ if plugins is not None: - plugins = [plugins] if not isinstance(plugins, list) else plugins + plugins = [plugins] if not isinstance(plugins, Iterable) else plugins if isinstance(strategy, str): strategy = strategy.lower() @@ -224,7 +225,7 @@ def _check_config_and_set_final_flags( precision_input = _convert_precision_to_unified_args(precision) if plugins: - plugins_flags_types: Dict[str, int] = Counter() + plugins_flags_types: dict[str, int] = Counter() for plugin in plugins: if isinstance(plugin, Precision): self._precision_instance = plugin @@ -295,7 +296,7 @@ def _check_config_and_set_final_flags( self._accelerator_flag = "cuda" self._parallel_devices = self._strategy_flag.parallel_devices - def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: + def _check_device_config_and_set_final_flags(self, devices: Union[list[int], str, int], num_nodes: int) -> None: if not isinstance(num_nodes, int) or num_nodes < 1: raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.") diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index 0ff5b04b30b0a..058e5e7c40751 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -13,20 +13,14 @@ # limitations under the License. import inspect import os -from contextlib import contextmanager, nullcontext +from collections.abc import Generator, Mapping, Sequence +from contextlib import AbstractContextManager, contextmanager, nullcontext from functools import partial from pathlib import Path from typing import ( Any, Callable, - ContextManager, - Dict, - Generator, - List, - Mapping, Optional, - Sequence, - Tuple, Union, cast, overload, @@ -118,12 +112,12 @@ def __init__( *, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - devices: Union[List[int], str, int] = "auto", + devices: Union[list[int], str, int] = "auto", num_nodes: int = 1, precision: Optional[_PRECISION_INPUT] = None, - plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, - callbacks: Optional[Union[List[Any], Any]] = None, - loggers: Optional[Union[Logger, List[Logger]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, + callbacks: Optional[Union[list[Any], Any]] = None, + loggers: Optional[Union[Logger, list[Logger]]] = None, ) -> None: self._connector = _Connector( accelerator=accelerator, @@ -192,7 +186,7 @@ def is_global_zero(self) -> bool: return self._strategy.is_global_zero @property - def loggers(self) -> List[Logger]: + def loggers(self) -> list[Logger]: """Returns all loggers passed to Fabric.""" return self._loggers @@ -326,7 +320,7 @@ def setup_module( self._models_setup += 1 return module - def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, Tuple[_FabricOptimizer, ...]]: + def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, tuple[_FabricOptimizer, ...]]: r"""Set up one or more optimizers for accelerated training. Some strategies do not allow setting up model and optimizer independently. For them, you should call @@ -349,7 +343,7 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, Tu def setup_dataloaders( self, *dataloaders: DataLoader, use_distributed_sampler: bool = True, move_to_device: bool = True - ) -> Union[DataLoader, List[DataLoader]]: + ) -> Union[DataLoader, list[DataLoader]]: r"""Set up one or multiple dataloaders for accelerated training. If you need different settings for each dataloader, call this method individually for each one. @@ -489,7 +483,7 @@ def clip_gradients( ) raise ValueError("You have to specify either `clip_val` or `max_norm` to do gradient clipping!") - def autocast(self) -> ContextManager: + def autocast(self) -> AbstractContextManager: """A context manager to automatically convert operations for the chosen precision. Use this only if the `forward` method of your model does not cover all operations you wish to run with the @@ -564,8 +558,8 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return self._strategy.broadcast(obj, src=src) def all_gather( - self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False - ) -> Union[Tensor, Dict, List, Tuple]: + self, data: Union[Tensor, dict, list, tuple], group: Optional[Any] = None, sync_grads: bool = False + ) -> Union[Tensor, dict, list, tuple]: """Gather tensors or collections of tensors from multiple processes. This method needs to be called on all processes and the tensors need to have the same shape across all @@ -589,10 +583,10 @@ def all_gather( def all_reduce( self, - data: Union[Tensor, Dict, List, Tuple], + data: Union[Tensor, dict, list, tuple], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean", - ) -> Union[Tensor, Dict, List, Tuple]: + ) -> Union[Tensor, dict, list, tuple]: """Reduce tensors or collections of tensors from multiple processes. The reduction on tensors is applied in-place, meaning the result will be placed back into the input tensor. @@ -639,7 +633,7 @@ def rank_zero_first(self, local: bool = False) -> Generator: if rank == 0: barrier() - def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager: + def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> AbstractContextManager: r"""Skip gradient synchronization during backward to avoid redundant communication overhead. Use this context manager when performing gradient accumulation to speed up training with multiple devices. @@ -681,7 +675,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte forward_module, _ = _unwrap_compiled(module._forward_module) return self._strategy._backward_sync_control.no_backward_sync(forward_module, enabled) - def sharded_model(self) -> ContextManager: + def sharded_model(self) -> AbstractContextManager: r"""Instantiate a model under this context manager to prepare it for model-parallel sharding. .. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead. @@ -693,12 +687,12 @@ def sharded_model(self) -> ContextManager: return self.strategy.module_sharded_context() return nullcontext() - def init_tensor(self) -> ContextManager: + def init_tensor(self) -> AbstractContextManager: """Tensors that you instantiate under this context manager will be created on the device right away and have the right data type depending on the precision setting in Fabric.""" return self._strategy.tensor_init_context() - def init_module(self, empty_init: Optional[bool] = None) -> ContextManager: + def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManager: """Instantiate the model and its parameters under this context manager to reduce peak memory usage. The parameters get created on the device and with the right data type right away without wasting memory being @@ -716,8 +710,8 @@ def init_module(self, empty_init: Optional[bool] = None) -> ContextManager: def save( self, path: Union[str, Path], - state: Dict[str, Union[nn.Module, Optimizer, Any]], - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + state: dict[str, Union[nn.Module, Optimizer, Any]], + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: r"""Save checkpoint contents to a file. @@ -750,9 +744,9 @@ def save( def load( self, path: Union[str, Path], - state: Optional[Dict[str, Union[nn.Module, Optimizer, Any]]] = None, + state: Optional[dict[str, Union[nn.Module, Optimizer, Any]]] = None, strict: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.) How and which processes load gets determined by the `strategy`. @@ -933,7 +927,7 @@ def _wrap_with_setup(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any: with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler): return to_run(*args, **kwargs) - def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: + def _move_model_to_device(self, model: nn.Module, optimizers: list[Optimizer]) -> nn.Module: try: initial_name, initial_param = next(model.named_parameters()) except StopIteration: @@ -1061,7 +1055,7 @@ def _validate_setup_dataloaders(self, dataloaders: Sequence[DataLoader]) -> None raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") @staticmethod - def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]: + def _configure_callbacks(callbacks: Optional[Union[list[Any], Any]]) -> list[Any]: callbacks = callbacks if callbacks is not None else [] callbacks = callbacks if isinstance(callbacks, list) else [callbacks] callbacks.extend(_load_external_callbacks("lightning.fabric.callbacks_factory")) diff --git a/src/lightning/fabric/loggers/csv_logs.py b/src/lightning/fabric/loggers/csv_logs.py index 4dbb56fa691db..dd7dfc63671f0 100644 --- a/src/lightning/fabric/loggers/csv_logs.py +++ b/src/lightning/fabric/loggers/csv_logs.py @@ -16,7 +16,7 @@ import logging import os from argparse import Namespace -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import override @@ -138,13 +138,13 @@ def experiment(self) -> "_ExperimentWriter": @override @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: raise NotImplementedError("The `CSVLogger` does not yet support logging hyperparameters.") @override @rank_zero_only def log_metrics( # type: ignore[override] - self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None + self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None ) -> None: metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) if step is None: @@ -200,8 +200,8 @@ class _ExperimentWriter: NAME_METRICS_FILE = "metrics.csv" def __init__(self, log_dir: str) -> None: - self.metrics: List[Dict[str, float]] = [] - self.metrics_keys: List[str] = [] + self.metrics: list[dict[str, float]] = [] + self.metrics_keys: list[str] = [] self._fs = get_filesystem(log_dir) self.log_dir = log_dir @@ -210,7 +210,7 @@ def __init__(self, log_dir: str) -> None: self._check_log_dir_exists() self._fs.makedirs(self.log_dir, exist_ok=True) - def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics_dict: dict[str, float], step: Optional[int] = None) -> None: """Record metrics.""" def _handle_value(value: Union[Tensor, Any]) -> Any: @@ -246,7 +246,7 @@ def save(self) -> None: self.metrics = [] # reset - def _record_new_keys(self) -> Set[str]: + def _record_new_keys(self) -> set[str]: """Records new keys that have not been logged before.""" current_keys = set().union(*self.metrics) new_keys = current_keys - set(self.metrics_keys) @@ -254,7 +254,7 @@ def _record_new_keys(self) -> Set[str]: self.metrics_keys.sort() return new_keys - def _rewrite_with_new_header(self, fieldnames: List[str]) -> None: + def _rewrite_with_new_header(self, fieldnames: list[str]) -> None: with self._fs.open(self.metrics_file_path, "r", newline="") as file: metrics = list(csv.DictReader(file)) diff --git a/src/lightning/fabric/loggers/logger.py b/src/lightning/fabric/loggers/logger.py index 5647ab9c1c7a2..39a9fa06a08d0 100644 --- a/src/lightning/fabric/loggers/logger.py +++ b/src/lightning/fabric/loggers/logger.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Optional, Union from torch import Tensor from torch.nn import Module @@ -55,7 +55,7 @@ def group_separator(self) -> str: return "/" @abstractmethod - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: """Records metrics. This method logs metrics as soon as it received them. Args: @@ -66,7 +66,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> pass @abstractmethod - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None: """Record hyperparameters. Args: diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index 685c832088818..208244dc38cd3 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -14,7 +14,8 @@ import os from argparse import Namespace -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Optional, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -219,15 +220,19 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) @override @rank_zero_only def log_hyperparams( - self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None + self, + params: Union[dict[str, Any], Namespace], + metrics: Optional[dict[str, Any]] = None, + step: Optional[int] = None, ) -> None: """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to display the new ones with hyperparameters. Args: - params: a dictionary-like container with the hyperparameters + params: A dictionary-like container with the hyperparameters metrics: Dictionary with metric names as keys and measured quantities as values + step: Optional global step number for the logged metrics """ params = _convert_params(params) @@ -243,7 +248,7 @@ def log_hyperparams( metrics = {"hp_metric": metrics} if metrics: - self.log_metrics(metrics, 0) + self.log_metrics(metrics, step) if _TENSORBOARD_AVAILABLE: from torch.utils.tensorboard.summary import hparams @@ -252,9 +257,9 @@ def log_hyperparams( exp, ssi, sei = hparams(params, metrics) writer = self.experiment._get_file_writer() - writer.add_summary(exp) - writer.add_summary(ssi) - writer.add_summary(sei) + writer.add_summary(exp, step) + writer.add_summary(ssi, step) + writer.add_summary(sei, step) @override @rank_zero_only @@ -318,12 +323,12 @@ def _get_next_version(self) -> int: return max(existing_versions) + 1 @staticmethod - def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: + def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]: params = _utils_sanitize_params(params) # logging of arrays with dimension > 1 is not supported, sanitize as string return {k: str(v) if hasattr(v, "ndim") and v.ndim > 1 else v for k, v in params.items()} - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() state["_experiment"] = None return state diff --git a/src/lightning/fabric/plugins/collectives/collective.py b/src/lightning/fabric/plugins/collectives/collective.py index 3b336b590aec8..9408fd87da400 100644 --- a/src/lightning/fabric/plugins/collectives/collective.py +++ b/src/lightning/fabric/plugins/collectives/collective.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any, Optional from torch import Tensor from typing_extensions import Self @@ -47,19 +47,19 @@ def all_reduce(self, tensor: Tensor, op: str) -> Tensor: ... def reduce(self, tensor: Tensor, dst: int, op: str) -> Tensor: ... @abstractmethod - def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]: ... + def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]: ... @abstractmethod - def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]: ... + def gather(self, tensor: Tensor, gather_list: list[Tensor], dst: int = 0) -> list[Tensor]: ... @abstractmethod - def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor: ... + def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> Tensor: ... @abstractmethod - def reduce_scatter(self, output: Tensor, input_list: List[Tensor], op: str) -> Tensor: ... + def reduce_scatter(self, output: Tensor, input_list: list[Tensor], op: str) -> Tensor: ... @abstractmethod - def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]: ... + def all_to_all(self, output_tensor_list: list[Tensor], input_tensor_list: list[Tensor]) -> list[Tensor]: ... @abstractmethod def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None: ... @@ -68,7 +68,7 @@ def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None: ... def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor: ... @abstractmethod - def barrier(self, device_ids: Optional[List[int]] = None) -> None: ... + def barrier(self, device_ids: Optional[list[int]] = None) -> None: ... @classmethod @abstractmethod diff --git a/src/lightning/fabric/plugins/collectives/single_device.py b/src/lightning/fabric/plugins/collectives/single_device.py index 9b635f6cdc7c1..73378715f2454 100644 --- a/src/lightning/fabric/plugins/collectives/single_device.py +++ b/src/lightning/fabric/plugins/collectives/single_device.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any from torch import Tensor from typing_extensions import override @@ -37,31 +37,31 @@ def reduce(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor: return tensor @override - def all_gather(self, tensor_list: List[Tensor], tensor: Tensor, **__: Any) -> List[Tensor]: + def all_gather(self, tensor_list: list[Tensor], tensor: Tensor, **__: Any) -> list[Tensor]: return [tensor] @override - def gather(self, tensor: Tensor, *_: Any, **__: Any) -> List[Tensor]: + def gather(self, tensor: Tensor, *_: Any, **__: Any) -> list[Tensor]: return [tensor] @override def scatter( self, tensor: Tensor, - scatter_list: List[Tensor], + scatter_list: list[Tensor], *_: Any, **__: Any, ) -> Tensor: return scatter_list[0] @override - def reduce_scatter(self, output: Tensor, input_list: List[Tensor], *_: Any, **__: Any) -> Tensor: + def reduce_scatter(self, output: Tensor, input_list: list[Tensor], *_: Any, **__: Any) -> Tensor: return input_list[0] @override def all_to_all( - self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor], *_: Any, **__: Any - ) -> List[Tensor]: + self, output_tensor_list: list[Tensor], input_tensor_list: list[Tensor], *_: Any, **__: Any + ) -> list[Tensor]: return input_tensor_list @override diff --git a/src/lightning/fabric/plugins/collectives/torch_collective.py b/src/lightning/fabric/plugins/collectives/torch_collective.py index 0dea3033f3dff..81e15a33cb983 100644 --- a/src/lightning/fabric/plugins/collectives/torch_collective.py +++ b/src/lightning/fabric/plugins/collectives/torch_collective.py @@ -1,6 +1,6 @@ import datetime import os -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import torch import torch.distributed as dist @@ -66,30 +66,30 @@ def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = return tensor @override - def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]: + def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]: dist.all_gather(tensor_list, tensor, group=self.group) return tensor_list @override - def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]: + def gather(self, tensor: Tensor, gather_list: list[Tensor], dst: int = 0) -> list[Tensor]: dist.gather(tensor, gather_list, dst, group=self.group) return gather_list @override - def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor: + def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> Tensor: dist.scatter(tensor, scatter_list, src, group=self.group) return tensor @override def reduce_scatter( - self, output: Tensor, input_list: List[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum" + self, output: Tensor, input_list: list[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum" ) -> Tensor: op = self._convert_to_native_op(op) dist.reduce_scatter(output, input_list, op=op, group=self.group) return output @override - def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]: + def all_to_all(self, output_tensor_list: list[Tensor], input_tensor_list: list[Tensor]) -> list[Tensor]: dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group) return output_tensor_list @@ -102,28 +102,28 @@ def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tenso dist.recv(tensor, src, tag=tag, group=self.group) # type: ignore[arg-type] return tensor - def all_gather_object(self, object_list: List[Any], obj: Any) -> List[Any]: + def all_gather_object(self, object_list: list[Any], obj: Any) -> list[Any]: dist.all_gather_object(object_list, obj, group=self.group) return object_list def broadcast_object_list( - self, object_list: List[Any], src: int, device: Optional[torch.device] = None - ) -> List[Any]: + self, object_list: list[Any], src: int, device: Optional[torch.device] = None + ) -> list[Any]: dist.broadcast_object_list(object_list, src, group=self.group, device=device) return object_list - def gather_object(self, obj: Any, object_gather_list: List[Any], dst: int = 0) -> List[Any]: + def gather_object(self, obj: Any, object_gather_list: list[Any], dst: int = 0) -> list[Any]: dist.gather_object(obj, object_gather_list, dst, group=self.group) return object_gather_list def scatter_object_list( - self, scatter_object_output_list: List[Any], scatter_object_input_list: List[Any], src: int = 0 - ) -> List[Any]: + self, scatter_object_output_list: list[Any], scatter_object_input_list: list[Any], src: int = 0 + ) -> list[Any]: dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group) return scatter_object_output_list @override - def barrier(self, device_ids: Optional[List[int]] = None) -> None: + def barrier(self, device_ids: Optional[list[int]] = None) -> None: if self.group == dist.GroupMember.NON_GROUP_MEMBER: return dist.barrier(group=self.group, device_ids=device_ids) diff --git a/src/lightning/fabric/plugins/environments/lsf.py b/src/lightning/fabric/plugins/environments/lsf.py index 6a23006d3c1d9..f0a07d61d9f03 100644 --- a/src/lightning/fabric/plugins/environments/lsf.py +++ b/src/lightning/fabric/plugins/environments/lsf.py @@ -14,7 +14,6 @@ import logging import os import socket -from typing import Dict, List from typing_extensions import override @@ -144,14 +143,14 @@ def _get_node_rank(self) -> int: """ hosts = self._read_hosts() - count: Dict[str, int] = {} + count: dict[str, int] = {} for host in hosts: if host not in count: count[host] = len(count) return count[socket.gethostname()] @staticmethod - def _read_hosts() -> List[str]: + def _read_hosts() -> list[str]: """Read compute hosts that are a part of the compute job. LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes. diff --git a/src/lightning/fabric/plugins/io/checkpoint_io.py b/src/lightning/fabric/plugins/io/checkpoint_io.py index 79fc9e88b737e..3a33dac3335d1 100644 --- a/src/lightning/fabric/plugins/io/checkpoint_io.py +++ b/src/lightning/fabric/plugins/io/checkpoint_io.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, Optional from lightning.fabric.utilities.types import _PATH @@ -36,7 +36,7 @@ class CheckpointIO(ABC): """ @abstractmethod - def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -47,7 +47,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio """ @abstractmethod - def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> Dict[str, Any]: + def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> dict[str, Any]: """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. Args: diff --git a/src/lightning/fabric/plugins/io/torch_io.py b/src/lightning/fabric/plugins/io/torch_io.py index 02de1aa274a32..90a5f62ba7413 100644 --- a/src/lightning/fabric/plugins/io/torch_io.py +++ b/src/lightning/fabric/plugins/io/torch_io.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional from typing_extensions import override @@ -34,7 +34,7 @@ class TorchCheckpointIO(CheckpointIO): """ @override - def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -60,7 +60,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio @override def load_checkpoint( self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files. Args: diff --git a/src/lightning/fabric/plugins/io/xla.py b/src/lightning/fabric/plugins/io/xla.py index 5c154d81a9915..146fa2f33b510 100644 --- a/src/lightning/fabric/plugins/io/xla.py +++ b/src/lightning/fabric/plugins/io/xla.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import os -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -41,7 +41,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @override - def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py index c624e821af28c..d5fc1f0c1cc2a 100644 --- a/src/lightning/fabric/plugins/precision/amp.py +++ b/src/lightning/fabric/plugins/precision/amp.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, ContextManager, Dict, Literal, Optional +from contextlib import AbstractContextManager +from typing import Any, Literal, Optional import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -59,7 +60,7 @@ def __init__( self._desired_input_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16 @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: return torch.autocast(self.device, dtype=self._desired_input_dtype) @override @@ -93,13 +94,13 @@ def optimizer_step( return step_output @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() return {} @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index 0f524dd67fad9..ecb1d8a442655 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -16,10 +16,11 @@ import math import os import warnings -from contextlib import ExitStack +from collections import OrderedDict +from contextlib import AbstractContextManager, ExitStack from functools import partial from types import ModuleType -from typing import Any, Callable, ContextManager, Literal, Optional, OrderedDict, Set, Tuple, Type, cast +from typing import Any, Callable, Literal, Optional, cast import torch from lightning_utilities import apply_to_collection @@ -43,7 +44,7 @@ class BitsandbytesPrecision(Precision): - """Plugin for quantizing weights with `bitsandbytes `__. + """Plugin for quantizing weights with `bitsandbytes `__. .. warning:: This is an :ref:`experimental ` feature. @@ -70,7 +71,7 @@ def __init__( self, mode: Literal["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"], dtype: Optional[torch.dtype] = None, - ignore_modules: Optional[Set[str]] = None, + ignore_modules: Optional[set[str]] = None, ) -> None: _import_bitsandbytes() @@ -122,11 +123,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: return module @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self.dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: if self.ignore_modules: # cannot patch the Linear class if the user wants to skip some submodules raise RuntimeError( @@ -144,7 +145,7 @@ def module_init_context(self) -> ContextManager: return stack @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: return _DtypeContextManager(self.dtype) @override @@ -175,7 +176,7 @@ def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _In def _replace_param( - param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[Tuple] = None + param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[tuple] = None ) -> torch.nn.Parameter: bnb = _import_bitsandbytes() @@ -184,11 +185,15 @@ def _replace_param( if param.device.type == "meta": if isinstance(param, bnb.nn.Params4bit): return bnb.nn.Params4bit( - data, + data=data, requires_grad=data.requires_grad, quant_state=quant_state, + blocksize=param.blocksize, compress_statistics=param.compress_statistics, quant_type=param.quant_type, + quant_storage=param.quant_storage, + module=param.module, + bnb_quantized=param.bnb_quantized, ) return torch.nn.Parameter(data, requires_grad=data.requires_grad) param.data = data @@ -322,6 +327,7 @@ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torc return assert isinstance(self.weight, bnb.nn.Params4bit) self.weight = self.quantize(self.weight, weight, device) + self.weight.bnb_quantized = True @staticmethod def quantize( @@ -337,6 +343,7 @@ def quantize( blocksize=params4bit.blocksize, compress_statistics=params4bit.compress_statistics, quant_type=params4bit.quant_type, + quant_storage=params4bit.quant_storage, ) return _replace_param(params4bit, w_4bit, quant_state) @@ -412,7 +419,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: return bnb -def _convert_layers(module: torch.nn.Module, linear_cls: Type, ignore_modules: Set[str], prefix: str = "") -> None: +def _convert_layers(module: torch.nn.Module, linear_cls: type, ignore_modules: set[str], prefix: str = "") -> None: for name, child in module.named_children(): fullname = f"{prefix}.{name}" if prefix else name if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules): diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py index 2fcaa38258e3a..526095008f376 100644 --- a/src/lightning/fabric/plugins/precision/deepspeed.py +++ b/src/lightning/fabric/plugins/precision/deepspeed.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext -from typing import TYPE_CHECKING, Any, ContextManager, Literal +from contextlib import AbstractContextManager, nullcontext +from typing import TYPE_CHECKING, Any, Literal import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -68,13 +68,13 @@ def convert_module(self, module: Module) -> Module: return module @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: if "true" not in self.precision: return nullcontext() return _DtypeContextManager(self._desired_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/fabric/plugins/precision/double.py b/src/lightning/fabric/plugins/precision/double.py index 0a857499f3d34..9aa0365a55e70 100644 --- a/src/lightning/fabric/plugins/precision/double.py +++ b/src/lightning/fabric/plugins/precision/double.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, ContextManager, Literal +from contextlib import AbstractContextManager +from typing import Any, Literal import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -33,15 +34,15 @@ def convert_module(self, module: Module) -> Module: return module.double() @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(torch.double) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py index 179fc21cdd90d..270a67e3a2338 100644 --- a/src/lightning/fabric/plugins/precision/fsdp.py +++ b/src/lightning/fabric/plugins/precision/fsdp.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, ContextManager, Dict, Literal, Optional +from contextlib import AbstractContextManager +from typing import TYPE_CHECKING, Any, Literal, Optional import torch from lightning_utilities import apply_to_collection @@ -73,6 +74,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca } self._desired_input_dtype = precision_to_type[self.precision] + @override + def convert_module(self, module: Module) -> Module: + if "true" in self.precision: + return module.to(dtype=self._desired_input_dtype) + return module + @property def mixed_precision_config(self) -> "TorchMixedPrecision": from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision @@ -100,15 +107,15 @@ def mixed_precision_config(self) -> "TorchMixedPrecision": ) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32) @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: if "mixed" in self.precision: return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)) return self.tensor_init_context() @@ -150,12 +157,12 @@ def unscale_gradients(self, optimizer: Optimizer) -> None: scaler.unscale_(optimizer) @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() return {} @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) diff --git a/src/lightning/fabric/plugins/precision/half.py b/src/lightning/fabric/plugins/precision/half.py index 32ca7da815213..fcb28ad33274c 100644 --- a/src/lightning/fabric/plugins/precision/half.py +++ b/src/lightning/fabric/plugins/precision/half.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, ContextManager, Literal +from contextlib import AbstractContextManager +from typing import Any, Literal import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -42,15 +43,15 @@ def convert_module(self, module: Module) -> Module: return module.to(dtype=self._desired_input_dtype) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py index fbff54f8e3595..1dfab2a7bc649 100644 --- a/src/lightning/fabric/plugins/precision/precision.py +++ b/src/lightning/fabric/plugins/precision/precision.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext -from typing import Any, ContextManager, Dict, Literal, Optional, Union +from contextlib import AbstractContextManager, nullcontext +from typing import Any, Literal, Optional, Union from torch import Tensor from torch.nn import Module @@ -53,11 +53,11 @@ def convert_module(self, module: Module) -> Module: """ return module - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: """Controls how tensors get created (device, dtype).""" return nullcontext() - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: """Instantiate module parameters or tensors in the precision type this plugin handles. This is optional and depends on the precision limitations during optimization. @@ -65,7 +65,7 @@ def module_init_context(self) -> ContextManager: """ return nullcontext() - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" return nullcontext() @@ -135,7 +135,7 @@ def main_params(self, optimizer: Optimizer) -> _PARAMETERS: def unscale_gradients(self, optimizer: Optimizer) -> None: return - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Called when saving a checkpoint, implement to generate precision plugin state_dict. Returns: @@ -144,7 +144,7 @@ def state_dict(self) -> Dict[str, Any]: """ return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Called when loading a checkpoint, implement to reload precision plugin state given precision plugin state_dict. diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py index cb5296b21fc39..c3ef84a453e73 100644 --- a/src/lightning/fabric/plugins/precision/transformer_engine.py +++ b/src/lightning/fabric/plugins/precision/transformer_engine.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from contextlib import ExitStack -from typing import TYPE_CHECKING, Any, ContextManager, Literal, Mapping, Optional, Union +from collections.abc import Mapping +from contextlib import AbstractContextManager, ExitStack +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -106,11 +107,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: return module @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self.weights_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: dtype_ctx = self.tensor_init_context() stack = ExitStack() if self.replace_layers: @@ -125,7 +126,7 @@ def module_init_context(self) -> ContextManager: return stack @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: dtype_ctx = _DtypeContextManager(self.weights_dtype) fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype) import transformer_engine.pytorch as te diff --git a/src/lightning/fabric/plugins/precision/utils.py b/src/lightning/fabric/plugins/precision/utils.py index 887dbc937a1f6..8362384cb1042 100644 --- a/src/lightning/fabric/plugins/precision/utils.py +++ b/src/lightning/fabric/plugins/precision/utils.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Mapping, Type, Union +from collections.abc import Mapping +from typing import Any, Union import torch from torch import Tensor @@ -43,7 +44,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: class _ClassReplacementContextManager: """A context manager to monkeypatch classes.""" - def __init__(self, mapping: Mapping[str, Type]) -> None: + def __init__(self, mapping: Mapping[str, type]) -> None: self._mapping = mapping self._originals = {} self._modules = {} diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index c38780655ce6e..ce47e4e403c34 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext +from contextlib import AbstractContextManager, nullcontext from datetime import timedelta -from typing import Any, ContextManager, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union import torch import torch.distributed @@ -55,7 +55,7 @@ class DDPStrategy(ParallelStrategy): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision: Optional[Precision] = None, @@ -99,7 +99,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def distributed_sampler_kwargs(self) -> dict[str, Any]: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property @@ -171,14 +171,14 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return obj[0] @override - def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]: + def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]: if isinstance(module, DistributedDataParallel): module = module.module return super().get_module_state_dict(module) @override def load_module_state_dict( - self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True + self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True ) -> None: if isinstance(module, DistributedDataParallel): module = module.module @@ -225,13 +225,13 @@ def _set_world_ranks(self) -> None: # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank - def _determine_ddp_device_ids(self) -> Optional[List[int]]: + def _determine_ddp_device_ids(self) -> Optional[list[int]]: return None if self.root_device.type == "cpu" else [self.root_device.index] class _DDPBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks gradient synchronization inside the :class:`~torch.nn.parallel.distributed.DistributedDataParallel` wrapper.""" if not enabled: diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 93a17f10c8998..1e94fa1166f93 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -16,10 +16,12 @@ import logging import os import platform -from contextlib import ExitStack +from collections.abc import Mapping +from contextlib import AbstractContextManager, ExitStack +from datetime import timedelta from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -28,6 +30,7 @@ from typing_extensions import override from lightning.fabric.accelerators import Accelerator, CUDAAccelerator +from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment from lightning.fabric.plugins.precision import Precision from lightning.fabric.strategies.ddp import DDPStrategy @@ -80,9 +83,9 @@ def __init__( reduce_bucket_size: int = 200_000_000, zero_allow_untested_optimizer: bool = True, logging_batch_size_per_gpu: Optional[int] = None, - config: Optional[Union[_PATH, Dict[str, Any]]] = None, + config: Optional[Union[_PATH, dict[str, Any]]] = None, logging_level: int = logging.WARN, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, loss_scale: float = 0, initial_scale_power: int = 16, @@ -96,6 +99,7 @@ def __init__( load_full_weights: bool = False, precision: Optional[Precision] = None, process_group_backend: Optional[str] = None, + timeout: Optional[timedelta] = default_pg_timeout, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large billion parameter models. `For more information: https://pytorch- @@ -240,6 +244,7 @@ def __init__( process_group_backend=process_group_backend, ) self._backward_sync_control = None # DeepSpeed handles gradient accumulation internally + self._timeout: Optional[timedelta] = timeout self.config = self._load_config(config) if self.config is None: @@ -302,7 +307,7 @@ def zero_stage_3(self) -> bool: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, int]: + def distributed_sampler_kwargs(self) -> dict[str, int]: return {"num_replicas": self.world_size, "rank": self.global_rank} @property @@ -311,8 +316,8 @@ def model(self) -> "DeepSpeedEngine": @override def setup_module_and_optimizers( - self, module: Module, optimizers: List[Optimizer] - ) -> Tuple["DeepSpeedEngine", List[Optimizer]]: + self, module: Module, optimizers: list[Optimizer] + ) -> tuple["DeepSpeedEngine", list[Optimizer]]: """Set up a model and multiple optimizers together. Currently, only a single optimizer is supported. @@ -352,7 +357,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: raise NotImplementedError(self._err_msg_joint_setup_required()) @override - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: if self.zero_stage_3 and empty_init is False: raise NotImplementedError( f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled." @@ -365,7 +370,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag return stack @override - def module_sharded_context(self) -> ContextManager: + def module_sharded_context(self) -> AbstractContextManager: # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context # manager. Later modifications through e.g. `Fabric.setup()` won't have an effect here. @@ -382,9 +387,9 @@ def module_sharded_context(self) -> ContextManager: def save_checkpoint( self, path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state in a checkpoint directory. @@ -447,9 +452,9 @@ def save_checkpoint( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. Args: @@ -595,10 +600,10 @@ def _initialize_engine( self, model: Module, optimizer: Optional[Optimizer] = None, - ) -> Tuple["DeepSpeedEngine", Optimizer]: + ) -> tuple["DeepSpeedEngine", Optimizer]: """Initialize one model and one optimizer with an optional learning rate scheduler. - This calls :func:`deepspeed.initialize` internally. + This calls ``deepspeed.initialize`` internally. """ import deepspeed @@ -647,7 +652,9 @@ def _init_deepspeed_distributed(self) -> None: f"MEMBER: {self.global_rank + 1}/{self.world_size}" ) self._process_group_backend = self._get_process_group_backend() - deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port) + deepspeed.init_distributed( + self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout + ) def _set_node_environment_variables(self) -> None: assert self.cluster_environment is not None @@ -714,7 +721,7 @@ def _create_default_config( overlap_events: bool, thread_count: int, **zero_kwargs: Any, - ) -> Dict: + ) -> dict: cfg = { "activation_checkpointing": { "partition_activations": partition_activations, @@ -769,9 +776,9 @@ def _restore_zero_state(self, module: Module, ckpt: Mapping[str, Any]) -> None: import deepspeed def load(module: torch.nn.Module, prefix: str = "") -> None: - missing_keys: List[str] = [] - unexpected_keys: List[str] = [] - error_msgs: List[str] = [] + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + error_msgs: list[str] = [] state_dict = ckpt["state_dict"] # copy state_dict so _load_from_state_dict can modify it @@ -802,7 +809,7 @@ def load(module: torch.nn.Module, prefix: str = "") -> None: load(module, prefix="") - def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]: + def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]: if config is None and self.DEEPSPEED_ENV_VAR in os.environ: rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") config = os.environ[self.DEEPSPEED_ENV_VAR] @@ -817,14 +824,14 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option return config -def _get_deepspeed_engines_from_state(state: Dict[str, Any]) -> List["DeepSpeedEngine"]: +def _get_deepspeed_engines_from_state(state: dict[str, Any]) -> list["DeepSpeedEngine"]: from deepspeed import DeepSpeedEngine modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module))) return [engine for engine in modules if isinstance(engine, DeepSpeedEngine)] -def _validate_state_keys(state: Dict[str, Any]) -> None: +def _validate_state_keys(state: dict[str, Any]) -> None: # DeepSpeed merges the client state into its internal engine state when saving, but it does not check for # colliding keys from the user. We explicitly check it here: deepspeed_internal_keys = { @@ -851,7 +858,7 @@ def _validate_state_keys(state: Dict[str, Any]) -> None: ) -def _validate_device_index_selection(parallel_devices: List[torch.device]) -> None: +def _validate_device_index_selection(parallel_devices: list[torch.device]) -> None: selected_device_indices = [device.index for device in parallel_devices] expected_device_indices = list(range(len(parallel_devices))) if selected_device_indices != expected_device_indices: @@ -903,7 +910,7 @@ def _validate_checkpoint_directory(path: _PATH) -> None: def _format_precision_config( - config: Dict[str, Any], + config: dict[str, Any], precision: str, loss_scale: float, loss_scale_window: int, diff --git a/src/lightning/fabric/strategies/dp.py b/src/lightning/fabric/strategies/dp.py index 2fed307af5129..f407040649c54 100644 --- a/src/lightning/fabric/strategies/dp.py +++ b/src/lightning/fabric/strategies/dp.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import torch from torch import Tensor @@ -35,7 +35,7 @@ class DataParallelStrategy(ParallelStrategy): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, checkpoint_io: Optional[CheckpointIO] = None, precision: Optional[Precision] = None, ): @@ -95,14 +95,14 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool: return decision @override - def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]: + def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]: if isinstance(module, DataParallel): module = module.module return super().get_module_state_dict(module) @override def load_module_state_dict( - self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True + self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True ) -> None: if isinstance(module, DataParallel): module = module.module diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index e7fdd29f6287f..9dd5b2c62d4c9 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -13,7 +13,8 @@ # limitations under the License. import shutil import warnings -from contextlib import ExitStack, nullcontext +from collections.abc import Generator +from contextlib import AbstractContextManager, ExitStack, nullcontext from datetime import timedelta from functools import partial from pathlib import Path @@ -21,15 +22,8 @@ TYPE_CHECKING, Any, Callable, - ContextManager, - Dict, - Generator, - List, Literal, Optional, - Set, - Tuple, - Type, Union, ) @@ -78,7 +72,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy - _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] + _POLICY = Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]] @@ -143,7 +137,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, precision: Optional[Precision] = None, process_group_backend: Optional[str] = None, @@ -151,11 +145,11 @@ def __init__( cpu_offload: Union[bool, "CPUOffload", None] = None, mixed_precision: Optional["MixedPrecision"] = None, auto_wrap_policy: Optional["_POLICY"] = None, - activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, + activation_checkpointing: Optional[Union[type[Module], list[type[Module]]]] = None, activation_checkpointing_policy: Optional["_POLICY"] = None, sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD", state_dict_type: Literal["full", "sharded"] = "sharded", - device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None, + device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None, **kwargs: Any, ) -> None: super().__init__( @@ -216,7 +210,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def distributed_sampler_kwargs(self) -> dict[str, Any]: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property @@ -267,8 +261,8 @@ def setup_environment(self) -> None: @override def setup_module_and_optimizers( - self, module: Module, optimizers: List[Optimizer] - ) -> Tuple[Module, List[Optimizer]]: + self, module: Module, optimizers: list[Optimizer] + ) -> tuple[Module, list[Optimizer]]: """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module and sets `use_orig_params=True` to keep the reference to the original parameters in the optimizer.""" use_orig_params = self._fsdp_kwargs.get("use_orig_params") @@ -340,7 +334,7 @@ def module_to_device(self, module: Module) -> None: pass @override - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() stack = ExitStack() @@ -354,7 +348,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag return stack @override - def module_sharded_context(self) -> ContextManager: + def module_sharded_context(self) -> AbstractContextManager: from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel from torch.distributed.fsdp.wrap import enable_wrap @@ -419,9 +413,9 @@ def clip_gradients_norm( def save_checkpoint( self, path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state to a checkpoint on disk. @@ -473,8 +467,8 @@ def save_checkpoint( # replace the modules and optimizer objects in the state with their local state dict # and separate the user's metadata - converted_state: Dict[str, Any] = {} - metadata: Dict[str, Any] = {} + converted_state: dict[str, Any] = {} + metadata: dict[str, Any] = {} with state_dict_ctx: for key, obj in state.items(): converted: Any @@ -499,7 +493,7 @@ def save_checkpoint( shutil.rmtree(path) state_dict_ctx = _get_full_state_dict_context(module, world_size=self.world_size) - full_state: Dict[str, Any] = {} + full_state: dict[str, Any] = {} with state_dict_ctx: for key, obj in state.items(): if isinstance(obj, Module): @@ -519,9 +513,9 @@ def save_checkpoint( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects.""" if not state: raise ValueError( @@ -683,9 +677,9 @@ def _set_world_ranks(self) -> None: def _activation_checkpointing_kwargs( - activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]], + activation_checkpointing: Optional[Union[type[Module], list[type[Module]]]], activation_checkpointing_policy: Optional["_POLICY"], -) -> Dict: +) -> dict: if activation_checkpointing is None and activation_checkpointing_policy is None: return {} if activation_checkpointing is not None and activation_checkpointing_policy is not None: @@ -707,7 +701,7 @@ def _activation_checkpointing_kwargs( return {"auto_wrap_policy": activation_checkpointing_policy} -def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict: +def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: dict) -> dict: if policy is None: return kwargs if isinstance(policy, set): @@ -719,7 +713,7 @@ def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict: return kwargs -def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwargs: Dict) -> None: +def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwargs: dict) -> None: if not activation_checkpointing_kwargs: return @@ -745,7 +739,7 @@ def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwa class _FSDPBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks gradient synchronization inside the :class:`~torch.distributed.fsdp.FullyShardedDataParallel` wrapper.""" if not enabled: @@ -768,7 +762,7 @@ def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUO return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload)) -def _init_sharding_strategy(sharding_strategy: "_SHARDING_STRATEGY", kwargs: Dict) -> "ShardingStrategy": +def _init_sharding_strategy(sharding_strategy: "_SHARDING_STRATEGY", kwargs: dict) -> "ShardingStrategy": from torch.distributed.fsdp import ShardingStrategy if kwargs.get("process_group") is not None and kwargs.get("device_mesh") is not None: @@ -858,7 +852,7 @@ def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device) metric.to(device) # `.to()` is in-place -def _distributed_checkpoint_save(converted_state: Dict[str, Any], path: Path) -> None: +def _distributed_checkpoint_save(converted_state: dict[str, Any], path: Path) -> None: if _TORCH_GREATER_EQUAL_2_3: from torch.distributed.checkpoint import save @@ -877,7 +871,7 @@ def _distributed_checkpoint_save(converted_state: Dict[str, Any], path: Path) -> save(converted_state, writer) -def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None: +def _distributed_checkpoint_load(module_state: dict[str, Any], path: Path) -> None: if _TORCH_GREATER_EQUAL_2_3: from torch.distributed.checkpoint import load diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py index 14a063f28f336..d9b96dca5471d 100644 --- a/src/lightning/fabric/strategies/launchers/multiprocessing.py +++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py @@ -16,7 +16,7 @@ from dataclasses import dataclass from multiprocessing.queues import SimpleQueue from textwrap import dedent -from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Optional +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional import torch import torch.backends.cudnn @@ -167,7 +167,7 @@ class _GlobalStateSnapshot: use_deterministic_algorithms: bool use_deterministic_algorithms_warn_only: bool cudnn_benchmark: bool - rng_states: Dict[str, Any] + rng_states: dict[str, Any] @classmethod def capture(cls) -> "_GlobalStateSnapshot": diff --git a/src/lightning/fabric/strategies/launchers/subprocess_script.py b/src/lightning/fabric/strategies/launchers/subprocess_script.py index 63ae8b0beee4b..a28fe971c7ac4 100644 --- a/src/lightning/fabric/strategies/launchers/subprocess_script.py +++ b/src/lightning/fabric/strategies/launchers/subprocess_script.py @@ -18,7 +18,8 @@ import sys import threading import time -from typing import Any, Callable, List, Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import Any, Callable, Optional from lightning_utilities.core.imports import RequirementCache from typing_extensions import override @@ -80,7 +81,7 @@ def __init__( self.cluster_environment = cluster_environment self.num_processes = num_processes self.num_nodes = num_nodes - self.procs: List[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher + self.procs: list[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher @property @override @@ -162,7 +163,7 @@ def _basic_subprocess_cmd() -> Sequence[str]: return [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:] -def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]: +def _hydra_subprocess_cmd(local_rank: int) -> tuple[Sequence[str], str]: from hydra.core.hydra_config import HydraConfig from hydra.utils import get_original_cwd, to_absolute_path @@ -183,13 +184,13 @@ def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]: return command, cwd -def _launch_process_observer(child_processes: List[subprocess.Popen]) -> None: +def _launch_process_observer(child_processes: list[subprocess.Popen]) -> None: """Launches a thread that runs along the main process and monitors the health of all processes.""" _ChildProcessObserver(child_processes=child_processes, main_pid=os.getpid()).start() class _ChildProcessObserver(threading.Thread): - def __init__(self, main_pid: int, child_processes: List[subprocess.Popen], sleep_period: int = 5) -> None: + def __init__(self, main_pid: int, child_processes: list[subprocess.Popen], sleep_period: int = 5) -> None: super().__init__(daemon=True, name="child-process-observer") # thread stops if the main process exits self._main_pid = main_pid self._child_processes = child_processes diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index 86b93d35e66f3..ad1fc19074d06 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -13,10 +13,11 @@ # limitations under the License. import itertools import shutil -from contextlib import ExitStack +from collections.abc import Generator +from contextlib import AbstractContextManager, ExitStack from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generator, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar, Union import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only @@ -144,7 +145,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def distributed_sampler_kwargs(self) -> dict[str, Any]: assert self.device_mesh is not None data_parallel_mesh = self.device_mesh["data_parallel"] return {"num_replicas": data_parallel_mesh.size(), "rank": data_parallel_mesh.get_local_rank()} @@ -194,7 +195,7 @@ def module_to_device(self, module: Module) -> None: pass @override - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: precision_init_ctx = self.precision.module_init_context() stack = ExitStack() if empty_init: @@ -234,9 +235,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def save_checkpoint( self, path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state to a checkpoint on disk. @@ -272,9 +273,9 @@ def save_checkpoint( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects.""" if not state: raise ValueError( @@ -318,12 +319,12 @@ def _set_world_ranks(self) -> None: class _ParallelBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks gradient synchronization inside the FSDP2 modules.""" return _FSDPNoSync(module=module, enabled=enabled) -class _FSDPNoSync(ContextManager): +class _FSDPNoSync(AbstractContextManager): def __init__(self, module: Module, enabled: bool) -> None: self._module = module self._enabled = enabled @@ -344,10 +345,10 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: def _save_checkpoint( path: Path, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], full_state_dict: bool, rank: int, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: if path.is_dir() and full_state_dict and not _is_sharded_checkpoint(path): raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}") @@ -373,8 +374,8 @@ def _save_checkpoint( # replace the modules and optimizer objects in the state with their local state dict # and separate the user's metadata - converted_state: Dict[str, Any] = {} - metadata: Dict[str, Any] = {} + converted_state: dict[str, Any] = {} + metadata: dict[str, Any] = {} for key, obj in state.items(): converted: Any if isinstance(obj, Module): @@ -405,10 +406,10 @@ def _save_checkpoint( def _load_checkpoint( path: Path, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], strict: bool = True, optimizer_states_from_list: bool = False, -) -> Dict[str, Any]: +) -> dict[str, Any]: from torch.distributed.checkpoint.state_dict import ( StateDictOptions, get_model_state_dict, @@ -537,7 +538,7 @@ def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int def _load_raw_module_state( - state_dict: Dict[str, Any], module: Module, world_size: int = 1, strict: bool = True + state_dict: dict[str, Any], module: Module, world_size: int = 1, strict: bool = True ) -> None: """Loads the state dict into the module by gathering all weights first and then and writing back to each shard.""" from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -583,7 +584,7 @@ def _named_parameters_and_buffers_to_load(module: Module) -> Generator: yield param_name, param -def _rekey_optimizer_state_if_needed(optimizer_state_dict: Dict[str, Any], module: Module) -> Dict[str, Any]: +def _rekey_optimizer_state_if_needed(optimizer_state_dict: dict[str, Any], module: Module) -> dict[str, Any]: """Handles the case where the optimizer state is saved from a normal optimizer and converts the keys to parameter names.""" from torch.distributed.fsdp import FullyShardedDataParallel as FSDP diff --git a/src/lightning/fabric/strategies/parallel.py b/src/lightning/fabric/strategies/parallel.py index a12a0611c90ab..d9bc1a03d1bb5 100644 --- a/src/lightning/fabric/strategies/parallel.py +++ b/src/lightning/fabric/strategies/parallel.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch from torch import Tensor @@ -33,7 +33,7 @@ class ParallelStrategy(Strategy, ABC): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision: Optional[Precision] = None, @@ -64,15 +64,15 @@ def is_global_zero(self) -> bool: return self.global_rank == 0 @property - def parallel_devices(self) -> Optional[List[torch.device]]: + def parallel_devices(self) -> Optional[list[torch.device]]: return self._parallel_devices @parallel_devices.setter - def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> None: + def parallel_devices(self, parallel_devices: Optional[list[torch.device]]) -> None: self._parallel_devices = parallel_devices @property - def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]: + def distributed_sampler_kwargs(self) -> Optional[dict[str, Any]]: """Arguments for the ``DistributedSampler``. If this method is not defined, or it returns ``None``, then the ``DistributedSampler`` will not be used. diff --git a/src/lightning/fabric/strategies/registry.py b/src/lightning/fabric/strategies/registry.py index d7899584b5a88..d2376463c3111 100644 --- a/src/lightning/fabric/strategies/registry.py +++ b/src/lightning/fabric/strategies/registry.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional from typing_extensions import override @@ -65,7 +65,7 @@ def register( if name in self and not override: raise ValueError(f"'{name}' is already present in the registry. HINT: Use `override=True`.") - data: Dict[str, Any] = {} + data: dict[str, Any] = {} data["description"] = description if description is not None else "" data["init_params"] = init_params @@ -104,7 +104,7 @@ def remove(self, name: str) -> None: """Removes the registered strategy by name.""" self.pop(name) - def available_strategies(self) -> List: + def available_strategies(self) -> list: """Returns a list of registered strategies.""" return list(self.keys()) diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index 6bfed6a270b68..4daad9b954b2f 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -13,8 +13,9 @@ # limitations under the License. import logging from abc import ABC, abstractmethod -from contextlib import ExitStack -from typing import Any, Callable, ContextManager, Dict, Iterable, List, Optional, Tuple, TypeVar, Union +from collections.abc import Iterable +from contextlib import AbstractContextManager, ExitStack +from typing import Any, Callable, Optional, TypeVar, Union import torch from torch import Tensor @@ -117,7 +118,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader: """ return dataloader - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: """Controls how tensors get created (device, dtype).""" precision_init_ctx = self.precision.tensor_init_context() stack = ExitStack() @@ -125,7 +126,7 @@ def tensor_init_context(self) -> ContextManager: stack.enter_context(precision_init_ctx) return stack - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: """A context manager wrapping the model instantiation. Here, the strategy can control how the parameters of the model get created (device, dtype) and or apply other @@ -144,8 +145,8 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag return stack def setup_module_and_optimizers( - self, module: Module, optimizers: List[Optimizer] - ) -> Tuple[Module, List[Optimizer]]: + self, module: Module, optimizers: list[Optimizer] + ) -> tuple[Module, list[Optimizer]]: """Set up a model and multiple optimizers together. The returned objects are expected to be in the same order they were passed in. The default implementation will @@ -256,9 +257,9 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool: def save_checkpoint( self, path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state as a checkpoint file. @@ -276,17 +277,17 @@ def save_checkpoint( if self.is_global_zero: self.checkpoint_io.save_checkpoint(checkpoint=state, path=path, storage_options=storage_options) - def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]: + def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]: """Returns model state.""" return module.state_dict() def load_module_state_dict( - self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True + self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True ) -> None: """Loads the given state into the model.""" module.load_state_dict(state_dict, strict=strict) - def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: + def get_optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: """Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom plugins. @@ -304,9 +305,9 @@ def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. Args: @@ -394,9 +395,9 @@ def _err_msg_joint_setup_required(self) -> str: ) def _convert_stateful_objects_in_state( - self, state: Dict[str, Union[Module, Optimizer, Any]], filter: Dict[str, Callable[[str, Any], bool]] - ) -> Dict[str, Any]: - converted_state: Dict[str, Any] = {} + self, state: dict[str, Union[Module, Optimizer, Any]], filter: dict[str, Callable[[str, Any], bool]] + ) -> dict[str, Any]: + converted_state: dict[str, Any] = {} for key, obj in state.items(): # convert the state if isinstance(obj, Module): @@ -421,7 +422,7 @@ class _BackwardSyncControl(ABC): """ @abstractmethod - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks the synchronization of gradients during the backward pass. This is a context manager. It is only effective if it wraps a call to `.backward()`. @@ -433,7 +434,7 @@ class _Sharded(ABC): """Mixin-interface for any :class:`Strategy` that wants to expose functionality for sharding model parameters.""" @abstractmethod - def module_sharded_context(self) -> ContextManager: + def module_sharded_context(self) -> AbstractContextManager: """A context manager that goes over the instantiation of an :class:`torch.nn.Module` and handles sharding of parameters on creation. @@ -454,7 +455,7 @@ def _validate_keys_for_strict_loading( def _apply_filter( - key: str, filter: Dict[str, Callable[[str, Any], bool]], source_dict: object, target_dict: Dict[str, Any] + key: str, filter: dict[str, Callable[[str, Any], bool]], source_dict: object, target_dict: dict[str, Any] ) -> None: # filter out if necessary if key in filter and isinstance(source_dict, dict): diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py index 28d65558cae90..3b2e10e87b0a7 100644 --- a/src/lightning/fabric/strategies/xla.py +++ b/src/lightning/fabric/strategies/xla.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import io -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from torch import Tensor @@ -43,7 +43,7 @@ class XLAStrategy(ParallelStrategy): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, checkpoint_io: Optional[XLACheckpointIO] = None, precision: Optional[XLAPrecision] = None, sync_module_states: bool = True, @@ -276,9 +276,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def save_checkpoint( self, path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state as a checkpoint file. diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py index 6da693bafb1c8..935ef72713bcc 100644 --- a/src/lightning/fabric/strategies/xla_fsdp.py +++ b/src/lightning/fabric/strategies/xla_fsdp.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import io -from contextlib import ExitStack, nullcontext +from contextlib import AbstractContextManager, ExitStack, nullcontext from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Literal, Optional, Set, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union import torch from torch import Tensor @@ -46,7 +46,7 @@ if TYPE_CHECKING: from torch_xla.distributed.parallel_loader import MpDeviceLoader -_POLICY_SET = Set[Type[Module]] +_POLICY_SET = set[type[Module]] _POLICY = Union[_POLICY_SET, Callable[[Module, bool, int], bool]] @@ -56,7 +56,7 @@ class XLAFSDPStrategy(ParallelStrategy, _Sharded): .. warning:: This is an :ref:`experimental ` feature. - For more information check out https://github.com/pytorch/xla/blob/master/docs/fsdp.md + For more information check out https://github.com/pytorch/xla/blob/v2.5.0/docs/fsdp.md Args: auto_wrap_policy: Same as ``auto_wrap_policy`` parameter in @@ -83,7 +83,7 @@ class XLAFSDPStrategy(ParallelStrategy, _Sharded): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, checkpoint_io: Optional[XLACheckpointIO] = None, precision: Optional[XLAPrecision] = None, auto_wrap_policy: Optional[_POLICY] = None, @@ -196,8 +196,8 @@ def setup_environment(self) -> None: @override def setup_module_and_optimizers( - self, module: Module, optimizers: List[Optimizer] - ) -> Tuple[Module, List[Optimizer]]: + self, module: Module, optimizers: list[Optimizer] + ) -> tuple[Module, list[Optimizer]]: """Returns NotImplementedError since for XLAFSDP optimizer setup must happen after module setup.""" raise NotImplementedError( f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)." @@ -225,7 +225,7 @@ def setup_module(self, module: Module) -> Module: def module_to_device(self, module: Module) -> None: pass - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() stack = ExitStack() @@ -235,7 +235,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag return stack @override - def module_sharded_context(self) -> ContextManager: + def module_sharded_context(self) -> AbstractContextManager: return nullcontext() @override @@ -408,9 +408,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def save_checkpoint( self, path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state in the provided checkpoint directory. @@ -483,13 +483,13 @@ def save_checkpoint( def _save_checkpoint_shard( self, path: Path, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any], - filter: Optional[Dict[str, Callable[[str, Any], bool]]], + filter: Optional[dict[str, Callable[[str, Any], bool]]], ) -> None: from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP - converted_state: Dict[str, Any] = {} + converted_state: dict[str, Any] = {} for key, obj in state.items(): # convert the state if isinstance(obj, Module) and isinstance(obj, XLAFSDP): @@ -512,9 +512,9 @@ def _save_checkpoint_shard( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Given a folder, load the contents from a checkpoint and restore the state of the given objects. The strategy currently only supports saving and loading sharded checkpoints which are stored in form of a @@ -617,7 +617,7 @@ def load_checkpoint( def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: strategy_registry.register("xla_fsdp", cls, description=cls.__name__) - def _parse_fsdp_kwargs(self) -> Dict: + def _parse_fsdp_kwargs(self) -> dict: # this needs to be delayed because `self.precision` isn't available at init kwargs = self._fsdp_kwargs.copy() precision = self.precision @@ -629,7 +629,7 @@ def _parse_fsdp_kwargs(self) -> Dict: return _activation_checkpointing_kwargs(self._activation_checkpointing_policy, kwargs) -def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict: +def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: dict) -> dict: if policy is None: return kwargs if isinstance(policy, set): @@ -649,7 +649,7 @@ def _activation_checkpointing_auto_wrapper(policy: _POLICY_SET, module: Module, return XLAFSDP(module, *args, **kwargs) -def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: Dict) -> Dict: +def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: dict) -> dict: if not policy: return kwargs if "auto_wrapper_callable" in kwargs: @@ -668,7 +668,7 @@ def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: Dict class _XLAFSDPBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks gradient synchronization inside the :class:`~torch_xla.distributed.fsdp.XlaFullyShardedDataParallel` wrapper.""" if not enabled: diff --git a/src/lightning/fabric/utilities/apply_func.py b/src/lightning/fabric/utilities/apply_func.py index d43565f494d3c..35693a5fcf1fb 100644 --- a/src/lightning/fabric/utilities/apply_func.py +++ b/src/lightning/fabric/utilities/apply_func.py @@ -15,7 +15,7 @@ from abc import ABC from functools import partial -from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Union import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -34,7 +34,7 @@ def _from_numpy(value: "np.ndarray", device: _DEVICE) -> Tensor: return torch.from_numpy(value).to(device) -CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any, Any], Tensor]]] = [ +CONVERSION_DTYPES: list[tuple[Any, Callable[[Any, Any], Tensor]]] = [ # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group (bool, partial(torch.tensor, dtype=torch.uint8)), (int, partial(torch.tensor, dtype=torch.int)), diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index 7ecc9eea501a6..9d0a33afd0b77 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -16,7 +16,7 @@ import io import logging from pathlib import Path -from typing import IO, Any, Dict, Union +from typing import IO, Any, Union import fsspec import fsspec.utils @@ -69,7 +69,7 @@ def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem: return fs -def _atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None: +def _atomic_save(checkpoint: dict[str, Any], filepath: Union[str, Path]) -> None: """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. Args: diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py index 1ec0edce38050..ea35d8c3da4a9 100644 --- a/src/lightning/fabric/utilities/data.py +++ b/src/lightning/fabric/utilities/data.py @@ -16,9 +16,10 @@ import inspect import os from collections import OrderedDict +from collections.abc import Generator, Iterable, Sized from contextlib import contextmanager from functools import partial -from typing import Any, Callable, Dict, Generator, Iterable, Optional, Sized, Tuple, Type, Union +from typing import Any, Callable, Optional, Union from lightning_utilities.core.inheritance import get_all_subclasses from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler @@ -79,7 +80,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable] def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, sampler: Union[Sampler, Iterable], -) -> Tuple[Tuple[Any], Dict[str, Any]]: +) -> tuple[tuple[Any], dict[str, Any]]: if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") @@ -172,7 +173,7 @@ def _get_dataloader_init_args_and_kwargs( def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, sampler: Union[Sampler, Iterable], -) -> Dict[str, Any]: +) -> dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re- instantiation.""" batch_sampler = getattr(dataloader, "batch_sampler") @@ -249,7 +250,7 @@ def _auto_add_worker_init_fn(dataloader: object, rank: int) -> None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) -def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any: +def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[type] = None, **kwargs: Any) -> Any: constructor = type(orig_object) if explicit_cls is None else explicit_cls try: @@ -355,7 +356,7 @@ def wrapper(obj: Any, *args: Any) -> None: @contextmanager -def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: +def _replace_dunder_methods(base_cls: type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`. It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods. @@ -366,8 +367,8 @@ def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = # Check that __init__ belongs to the class # https://stackoverflow.com/a/5253424 if "__init__" in cls.__dict__: - cls.__old__init__ = cls.__init__ - cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) + cls.__old__init__ = cls.__init__ # type: ignore[misc] + cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) # type: ignore[misc] # we want at least one setattr/delattr in the chain to be patched and it can happen, that none of the subclasses # implement `__setattr__`/`__delattr__`. Therefore, we are always patching the `base_cls` @@ -389,11 +390,11 @@ def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = def _replace_value_in_saved_args( replace_key: str, replace_value: Any, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - default_kwargs: Dict[str, Any], - arg_names: Tuple[str, ...], -) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]: + args: tuple[Any, ...], + kwargs: dict[str, Any], + default_kwargs: dict[str, Any], + arg_names: tuple[str, ...], +) -> tuple[bool, tuple[Any, ...], dict[str, Any]]: """Tries to replace an argument value in a saved list of args and kwargs. Returns a tuple indicating success of the operation and modified saved args and kwargs @@ -420,7 +421,7 @@ def _set_sampler_epoch(dataloader: object, epoch: int) -> None: """ # cannot use a set because samplers might be unhashable: use a dict based on the id to drop duplicates - objects: Dict[int, Any] = {} + objects: dict[int, Any] = {} # check dataloader.sampler if (sampler := getattr(dataloader, "sampler", None)) is not None: objects[id(sampler)] = sampler @@ -458,7 +459,7 @@ def _num_cpus_available() -> int: return 1 if cpu_count is None else cpu_count -class AttributeDict(Dict): +class AttributeDict(dict): """A container to store state variables of your program. This is a drop-in replacement for a Python dictionary, with the additional functionality to access and modify keys diff --git a/src/lightning/fabric/utilities/device_dtype_mixin.py b/src/lightning/fabric/utilities/device_dtype_mixin.py index 9f06dc50cfbef..ff5a0949e4207 100644 --- a/src/lightning/fabric/utilities/device_dtype_mixin.py +++ b/src/lightning/fabric/utilities/device_dtype_mixin.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import torch from torch.nn import Module @@ -20,7 +20,7 @@ class _DeviceDtypeModuleMixin(Module): - __jit_unused_properties__: List[str] = ["device", "dtype"] + __jit_unused_properties__: list[str] = ["device", "dtype"] def __init__(self) -> None: super().__init__() diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py index 16965d944caec..ff5bebd9b4516 100644 --- a/src/lightning/fabric/utilities/device_parser.py +++ b/src/lightning/fabric/utilities/device_parser.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, MutableSequence, Optional, Tuple, Union +from collections.abc import MutableSequence +from typing import Optional, Union import torch @@ -19,7 +20,7 @@ from lightning.fabric.utilities.types import _DEVICE -def _determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]: +def _determine_root_gpu_device(gpus: list[_DEVICE]) -> Optional[_DEVICE]: """ Args: gpus: Non-empty list of ints representing which GPUs to use @@ -46,10 +47,10 @@ def _determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]: def _parse_gpu_ids( - gpus: Optional[Union[int, str, List[int]]], + gpus: Optional[Union[int, str, list[int]]], include_cuda: bool = False, include_mps: bool = False, -) -> Optional[List[int]]: +) -> Optional[list[int]]: """Parses the GPU IDs given in the format as accepted by the :class:`~lightning.pytorch.trainer.trainer.Trainer`. Args: @@ -102,7 +103,7 @@ def _parse_gpu_ids( return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps) -def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]: +def _normalize_parse_gpu_string_input(s: Union[int, str, list[int]]) -> Union[int, list[int]]: if not isinstance(s, str): return s if s == "-1": @@ -112,7 +113,7 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in return int(s.strip()) -def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]: +def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps: bool = False) -> list[int]: """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the GPUs is not available. @@ -139,8 +140,8 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: def _normalize_parse_gpu_input_to_list( - gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool -) -> Optional[List[int]]: + gpus: Union[int, list[int], tuple[int, ...]], include_cuda: bool, include_mps: bool +) -> Optional[list[int]]: assert gpus is not None if isinstance(gpus, (MutableSequence, tuple)): return list(gpus) @@ -154,7 +155,7 @@ def _normalize_parse_gpu_input_to_list( return list(range(gpus)) -def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]: +def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> list[int]: """ Returns: A list of all available GPUs @@ -167,7 +168,7 @@ def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = Fals return cuda_gpus + mps_gpus -def _check_unique(device_ids: List[int]) -> None: +def _check_unique(device_ids: list[int]) -> None: """Checks that the device_ids are unique. Args: diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index 0e6c52dfb09b9..ec4eb261f2d3e 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -4,10 +4,11 @@ import os import signal import time +from collections.abc import Iterable, Iterator, Sized from contextlib import nullcontext from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch import torch.nn.functional as F @@ -99,7 +100,7 @@ def is_shared_filesystem(strategy: "Strategy", path: Optional[_PATH] = None, tim return all_found -def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]: +def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> list[Tensor]: """Function to gather all tensors from several DDP processes onto a list that is broadcasted to all processes. Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case @@ -153,7 +154,7 @@ def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Ten return gathered_result -def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]: +def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> list[Tensor]: gathered_result = [torch.zeros_like(result) for _ in range(world_size)] torch.distributed.all_gather(gathered_result, result, group) return gathered_result @@ -345,7 +346,7 @@ def __init__(self, sampler: Union[Sampler, Iterable]) -> None: ) self._sampler = sampler # defer materializing an iterator until it is necessary - self._sampler_list: Optional[List[Any]] = None + self._sampler_list: Optional[list[Any]] = None @override def __getitem__(self, index: int) -> Any: diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index 4dbd57e531859..a1c5a6f6dcd1b 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -31,7 +31,9 @@ _TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0") _TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0") +_TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0") _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") +_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py index c92dfd8c2e82b..2760c6bd227c1 100644 --- a/src/lightning/fabric/utilities/init.py +++ b/src/lightning/fabric/utilities/init.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from typing import Any, Callable, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union import torch from torch.nn import Module, Parameter @@ -46,7 +47,7 @@ def __torch_function__( func: Callable, types: Sequence, args: Sequence[Any] = (), - kwargs: Optional[Dict] = None, + kwargs: Optional[dict] = None, ) -> Any: kwargs = kwargs or {} if not self.enabled: diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index a1c3b6933b2f6..9e158c1677b6d 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -13,10 +13,12 @@ import os import pickle import warnings +from collections import OrderedDict +from collections.abc import Sequence from functools import partial from io import BytesIO from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union +from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Union import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -134,7 +136,7 @@ def __torch_function__( func: Callable, types: Sequence, args: Sequence[Any] = (), - kwargs: Optional[Dict] = None, + kwargs: Optional[dict] = None, ) -> Any: kwargs = kwargs or {} loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args] @@ -219,7 +221,7 @@ def _load_tensor(t: _NotYetLoadedTensor) -> Tensor: def _move_state_into( - source: Dict[str, Any], destination: Dict[str, Union[Any, _Stateful]], keys: Optional[Set[str]] = None + source: dict[str, Any], destination: dict[str, Union[Any, _Stateful]], keys: Optional[set[str]] = None ) -> None: """Takes the state from the source destination and moves it into the destination dictionary. @@ -235,7 +237,7 @@ def _move_state_into( destination[key] = state -def _load_distributed_checkpoint(checkpoint_folder: Path) -> Dict[str, Any]: +def _load_distributed_checkpoint(checkpoint_folder: Path) -> dict[str, Any]: """Loads a sharded checkpoint saved with the `torch.distributed.checkpoint` into a full state dict. The current implementation assumes that the entire checkpoint fits in CPU memory. @@ -248,7 +250,7 @@ def _load_distributed_checkpoint(checkpoint_folder: Path) -> Dict[str, Any]: from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner from torch.distributed.checkpoint.state_dict_loader import _load_state_dict - checkpoint: Dict[str, Any] = {} + checkpoint: dict[str, Any] = {} _load_state_dict( checkpoint, storage_reader=FileSystemReader(checkpoint_folder), diff --git a/src/lightning/fabric/utilities/logger.py b/src/lightning/fabric/utilities/logger.py index 07b76ad9b04d8..dd2b0a3663fc9 100644 --- a/src/lightning/fabric/utilities/logger.py +++ b/src/lightning/fabric/utilities/logger.py @@ -15,15 +15,16 @@ import inspect import json from argparse import Namespace +from collections.abc import Mapping, MutableMapping from dataclasses import asdict, is_dataclass -from typing import Any, Dict, Mapping, MutableMapping, Optional, Union +from typing import Any, Optional, Union from torch import Tensor from lightning.fabric.utilities.imports import _NUMPY_AVAILABLE -def _convert_params(params: Optional[Union[Dict[str, Any], Namespace]]) -> Dict[str, Any]: +def _convert_params(params: Optional[Union[dict[str, Any], Namespace]]) -> dict[str, Any]: """Ensure parameters are a dict or convert to dict if necessary. Args: @@ -43,7 +44,7 @@ def _convert_params(params: Optional[Union[Dict[str, Any], Namespace]]) -> Dict[ return params -def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]: +def _sanitize_callable_params(params: dict[str, Any]) -> dict[str, Any]: """Sanitize callable params dict, e.g. ``{'a': } -> {'a': 'function_****'}``. Args: @@ -73,7 +74,7 @@ def _sanitize_callable(val: Any) -> Any: return {key: _sanitize_callable(val) for key, val in params.items()} -def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent_key: str = "") -> Dict[str, Any]: +def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent_key: str = "") -> dict[str, Any]: """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``. Args: @@ -92,7 +93,7 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent {'5/a': 123} """ - result: Dict[str, Any] = {} + result: dict[str, Any] = {} for k, v in params.items(): new_key = parent_key + delimiter + str(k) if parent_key else str(k) if is_dataclass(v) and not isinstance(v, type): @@ -107,7 +108,7 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent return result -def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: +def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]: """Returns params with non-primitvies converted to strings for logging. >>> import torch @@ -140,7 +141,7 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: return params -def _convert_json_serializable(params: Dict[str, Any]) -> Dict[str, Any]: +def _convert_json_serializable(params: dict[str, Any]) -> dict[str, Any]: """Convert non-serializable objects in params to string.""" return {k: str(v) if not _is_json_serializable(v) else v for k, v in params.items()} diff --git a/src/lightning/fabric/utilities/optimizer.py b/src/lightning/fabric/utilities/optimizer.py index 2c57ec9d1f64a..df83f9b1ca542 100644 --- a/src/lightning/fabric/utilities/optimizer.py +++ b/src/lightning/fabric/utilities/optimizer.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import MutableMapping -from typing import Iterable +from collections.abc import Iterable, MutableMapping from torch import Tensor from torch.optim import Optimizer diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py index 9ad0f90221429..7d8f6ca17712e 100644 --- a/src/lightning/fabric/utilities/registry.py +++ b/src/lightning/fabric/utilities/registry.py @@ -15,7 +15,7 @@ from importlib.metadata import entry_points from inspect import getmembers, isclass from types import ModuleType -from typing import Any, List, Type, Union +from typing import Any, Union from lightning_utilities import is_overridden @@ -24,7 +24,7 @@ _log = logging.getLogger(__name__) -def _load_external_callbacks(group: str) -> List[Any]: +def _load_external_callbacks(group: str) -> list[Any]: """Collect external callbacks registered through entry points. The entry points are expected to be functions returning a list of callbacks. @@ -40,10 +40,10 @@ def _load_external_callbacks(group: str) -> List[Any]: entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) # type: ignore[arg-type] ) - external_callbacks: List[Any] = [] + external_callbacks: list[Any] = [] for factory in factories: callback_factory = factory.load() - callbacks_list: Union[List[Any], Any] = callback_factory() + callbacks_list: Union[list[Any], Any] = callback_factory() callbacks_list = [callbacks_list] if not isinstance(callbacks_list, list) else callbacks_list if callbacks_list: _log.info( @@ -54,7 +54,7 @@ def _load_external_callbacks(group: str) -> List[Any]: return external_callbacks -def _register_classes(registry: Any, method: str, module: ModuleType, parent: Type[object]) -> None: +def _register_classes(registry: Any, method: str, module: ModuleType, parent: type[object]) -> None: for _, member in getmembers(module, isclass): if issubclass(member, parent) and is_overridden(method, member, parent): register_fn = getattr(member, method) diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index a2d627828a77e..f9c0ddeb86cf0 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -3,7 +3,7 @@ import random from random import getstate as python_get_rng_state from random import setstate as python_set_rng_state -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch @@ -104,10 +104,13 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: if _NUMPY_AVAILABLE: import numpy as np - np.random.seed(seed_sequence[3] & 0xFFFFFFFF) # numpy takes 32-bit seed only + ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) + np_rng_seed = ss.generate_state(4) + np.random.seed(np_rng_seed) -def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> List[int]: + +def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> list[int]: """Generates a sequence of seeds from a base seed, worker id and rank using the linear congruential generator (LCG) algorithm.""" # Combine base seed, worker id and rank into a unique 64-bit number @@ -120,7 +123,7 @@ def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, co return seeds -def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: +def _collect_rng_states(include_cuda: bool = True) -> dict[str, Any]: r"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python.""" states = { "torch": torch.get_rng_state(), @@ -135,7 +138,7 @@ def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: return states -def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None: +def _set_rng_states(rng_state_dict: dict[str, Any]) -> None: r"""Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current process.""" torch.set_rng_state(rng_state_dict["torch"]) diff --git a/src/lightning/fabric/utilities/spike.py b/src/lightning/fabric/utilities/spike.py index 5dca5990064e8..04c554461c58c 100644 --- a/src/lightning/fabric/utilities/spike.py +++ b/src/lightning/fabric/utilities/spike.py @@ -2,7 +2,7 @@ import operator import os import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch from lightning_utilities.core.imports import compare_version @@ -66,7 +66,7 @@ def __init__( self.warmup = warmup self.atol = atol self.rtol = rtol - self.bad_batches: List[int] = [] + self.bad_batches: list[int] = [] self.exclude_batches_path = exclude_batches_path self.finite_only = finite_only @@ -147,7 +147,7 @@ def _update_stats(self, val: torch.Tensor) -> None: self.running_mean.update(val) self.last_val = val - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "last_val": self.last_val.item() if isinstance(self.last_val, torch.Tensor) else self.last_val, "mode": self.mode, @@ -160,7 +160,7 @@ def state_dict(self) -> Dict[str, Any]: "mean": self.running_mean.base_metric.state_dict(), } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.last_val = state_dict.pop("last_val") self.mode = state_dict.pop("mode") self.warmup = state_dict.pop("warmup") diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index 6f0513465cab5..6f5d933f9dae3 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -14,7 +14,7 @@ import operator import os import sys -from typing import Dict, List, Optional, Tuple +from typing import Optional import torch from lightning_utilities.core.imports import RequirementCache, compare_version @@ -40,7 +40,7 @@ def _runif_reasons( standalone: bool = False, deepspeed: bool = False, dynamo: bool = False, -) -> Tuple[List[str], Dict[str, bool]]: +) -> tuple[list[str], dict[str, bool]]: """Construct reasons for pytest skipif. Args: diff --git a/src/lightning/fabric/utilities/throughput.py b/src/lightning/fabric/utilities/throughput.py index 6743da7b34085..72b33a41f168c 100644 --- a/src/lightning/fabric/utilities/throughput.py +++ b/src/lightning/fabric/utilities/throughput.py @@ -13,7 +13,7 @@ # limitations under the License. # Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820/composer/callbacks/speed_monitor.py from collections import deque -from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union import torch from typing_extensions import override @@ -24,7 +24,7 @@ from lightning.fabric import Fabric from lightning.fabric.plugins import Precision -_THROUGHPUT_METRICS = Dict[str, Union[int, float]] +_THROUGHPUT_METRICS = dict[str, Union[int, float]] # The API design of this class follows `torchmetrics.Metric` but it doesn't need to be an actual Metric because there's @@ -108,7 +108,7 @@ def __init__( self._batches: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size) self._samples: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size) self._lengths: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size) - self._flops: Deque[int] = deque(maxlen=window_size) + self._flops: deque[int] = deque(maxlen=window_size) def update( self, @@ -302,7 +302,7 @@ def measure_flops( return flop_counter.get_total_flops() -_CUDA_FLOPS: Dict[str, Dict[Union[str, torch.dtype], float]] = { +_CUDA_FLOPS: dict[str, dict[Union[str, torch.dtype], float]] = { # Hopper # source: https://resources.nvidia.com/en-us-tensor-core "h100 nvl": { @@ -347,6 +347,14 @@ def measure_flops( torch.int8: 389.9e12, "int4": 779.8e12, }, + "rtx 4080 super": { + torch.float32: 52.2e12, + "tfloat32": 52.2e12, + torch.bfloat16: 52.2e12, + torch.float16: 52.2e12, + torch.int8: 417.6e12, + "int4": 835.2e12, + }, "l4": { torch.float32: 30.3e12, "tfloat32": 60e12, @@ -640,7 +648,7 @@ def _plugin_to_compute_dtype(plugin: "Precision") -> torch.dtype: T = TypeVar("T", bound=float) -class _MonotonicWindow(List[T]): +class _MonotonicWindow(list[T]): """Custom fixed size list that only supports right-append and ensures that all values increase monotonically.""" def __init__(self, maxlen: int) -> None: diff --git a/src/lightning/fabric/utilities/types.py b/src/lightning/fabric/utilities/types.py index 2e18dc89b05b2..1d7235fa36383 100644 --- a/src/lightning/fabric/utilities/types.py +++ b/src/lightning/fabric/utilities/types.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict +from collections.abc import Iterator from pathlib import Path from typing import ( Any, Callable, - DefaultDict, - Dict, - Iterator, - List, Optional, Protocol, TypeVar, @@ -38,7 +36,7 @@ _PATH = Union[str, Path] _DEVICE = Union[torch.device, str, int] _MAP_LOCATION_TYPE = Optional[ - Union[_DEVICE, Callable[[UntypedStorage, str], Optional[UntypedStorage]], Dict[_DEVICE, _DEVICE]] + Union[_DEVICE, Callable[[UntypedStorage, str], Optional[UntypedStorage]], dict[_DEVICE, _DEVICE]] ] _PARAMETERS = Iterator[torch.nn.Parameter] @@ -57,9 +55,9 @@ class _Stateful(Protocol[_DictKey]): """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`.""" - def state_dict(self) -> Dict[_DictKey, Any]: ... + def state_dict(self) -> dict[_DictKey, Any]: ... - def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None: ... + def load_state_dict(self, state_dict: dict[_DictKey, Any]) -> None: ... @runtime_checkable @@ -86,10 +84,10 @@ def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]: class Optimizable(Steppable, Protocol): """To structurally type ``optimizer``""" - param_groups: List[Dict[Any, Any]] - defaults: Dict[Any, Any] - state: DefaultDict[Tensor, Any] + param_groups: list[dict[Any, Any]] + defaults: dict[Any, Any] + state: defaultdict[Tensor, Any] - def state_dict(self) -> Dict[str, Dict[Any, Any]]: ... + def state_dict(self) -> dict[str, dict[Any, Any]]: ... - def load_state_dict(self, state_dict: Dict[str, Dict[Any, Any]]) -> None: ... + def load_state_dict(self, state_dict: dict[str, dict[Any, Any]]) -> None: ... diff --git a/src/lightning/fabric/utilities/warnings.py b/src/lightning/fabric/utilities/warnings.py index 62e5f5fc2ff17..b62bece384e32 100644 --- a/src/lightning/fabric/utilities/warnings.py +++ b/src/lightning/fabric/utilities/warnings.py @@ -15,7 +15,7 @@ import warnings from pathlib import Path -from typing import Optional, Type, Union +from typing import Optional, Union from lightning.fabric.utilities.rank_zero import LightningDeprecationWarning @@ -38,7 +38,7 @@ def disable_possible_user_warnings(module: str = "") -> None: def _custom_format_warning( - message: Union[Warning, str], category: Type[Warning], filename: str, lineno: int, line: Optional[str] = None + message: Union[Warning, str], category: type[Warning], filename: str, lineno: int, line: Optional[str] = None ) -> str: """Custom formatting that avoids an extra line in case warnings are emitted from the `rank_zero`-functions.""" if _is_path_in_lightning(Path(filename)): diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index c57f1974a6bba..b593c9f22ed23 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -12,19 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +from collections.abc import Generator, Iterator, Mapping from copy import deepcopy from functools import partial, wraps from types import MethodType from typing import ( Any, Callable, - Dict, - Generator, - Iterator, - List, - Mapping, Optional, - Tuple, TypeVar, Union, overload, @@ -48,14 +43,14 @@ from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning.fabric.utilities.types import Optimizable -T_destination = TypeVar("T_destination", bound=Dict[str, Any]) +T_destination = TypeVar("T_destination", bound=dict[str, Any]) _LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step") _in_fabric_backward: bool = False class _FabricOptimizer: - def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[List[Callable]] = None) -> None: + def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[list[Callable]] = None) -> None: """FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer step calls to the strategy. @@ -76,10 +71,10 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional def optimizer(self) -> Optimizer: return self._optimizer - def state_dict(self) -> Dict[str, Tensor]: + def state_dict(self) -> dict[str, Tensor]: return self._strategy.get_optimizer_state(self.optimizer) - def load_state_dict(self, state_dict: Dict[str, Tensor]) -> None: + def load_state_dict(self, state_dict: dict[str, Tensor]) -> None: self.optimizer.load_state_dict(state_dict) def step(self, closure: Optional[Callable] = None) -> Any: @@ -149,12 +144,12 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ... @overload - def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ... + def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> dict[str, Any]: ... @override def state_dict( self, destination: Optional[T_destination] = None, prefix: str = "", keep_vars: bool = False - ) -> Optional[Dict[str, Any]]: + ) -> Optional[dict[str, Any]]: return self._original_module.state_dict( destination=destination, # type: ignore[type-var] prefix=prefix, @@ -350,7 +345,7 @@ def _unwrap( return apply_to_collection(collection, dtype=tuple(types), function=_unwrap) -def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> Tuple[Union[Any, nn.Module], Optional[Dict[str, Any]]]: +def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> tuple[Union[Any, nn.Module], Optional[dict[str, Any]]]: """Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped. Use this function before instance checks against e.g. :class:`_FabricModule`. @@ -366,7 +361,7 @@ def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> Tuple[Union[Any, nn.Mo return obj, None -def _to_compiled(module: nn.Module, compile_kwargs: Dict[str, Any]) -> OptimizedModule: +def _to_compiled(module: nn.Module, compile_kwargs: dict[str, Any]) -> OptimizedModule: return torch.compile(module, **compile_kwargs) # type: ignore[return-value] diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 1b251b8fb06fa..9c546c42915eb 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -4,6 +4,27 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [unreleased] - YYYY-MM-DD + +### Changed + +- CometML logger was updated to support the recent Comet SDK ([#20275](https://github.com/Lightning-AI/pytorch-lightning/pull/20275)) + + +## [unreleased] - YYYY-MM-DD + +### Added + +### Changed + +- Merging of hparams when logging now ignores parameter names that begin with underscore `_` ([#20221](https://github.com/Lightning-AI/pytorch-lightning/pull/20221)) + +### Removed + +### Fixed + +- Fix LightningCLI failing when both module and data module save hyperparameters due to conflicting internal `_class_path` parameter ([#20221](https://github.com/Lightning-AI/pytorch-lightning/pull/20221)) + ## [2.4.0] - 2024-08-06 @@ -35,6 +56,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814)) - Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019)) - Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163)) +- Fixed PyTorch Lightning FSDP takes more memory than PyTorch FSDP ([#20323](https://github.com/Lightning-AI/pytorch-lightning/pull/20323)) ## [2.3.0] - 2024-06-13 diff --git a/src/lightning/pytorch/accelerators/accelerator.py b/src/lightning/pytorch/accelerators/accelerator.py index 0490c2d86431c..9238071178a80 100644 --- a/src/lightning/pytorch/accelerators/accelerator.py +++ b/src/lightning/pytorch/accelerators/accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC -from typing import Any, Dict +from typing import Any import lightning.pytorch as pl from lightning.fabric.accelerators.accelerator import Accelerator as _Accelerator @@ -34,7 +34,7 @@ def setup(self, trainer: "pl.Trainer") -> None: """ - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Get stats for a given device. Args: diff --git a/src/lightning/pytorch/accelerators/cpu.py b/src/lightning/pytorch/accelerators/cpu.py index 735312b363d11..525071cbb377f 100644 --- a/src/lightning/pytorch/accelerators/cpu.py +++ b/src/lightning/pytorch/accelerators/cpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Union +from typing import Any, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -38,7 +38,7 @@ def setup_device(self, device: torch.device) -> None: raise MisconfigurationException(f"Device should be CPU, got {device} instead.") @override - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Get CPU stats from ``psutil`` package.""" return get_cpu_stats() @@ -48,13 +48,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> int: + def parse_devices(devices: Union[int, str]) -> int: """Accelerator device parsing logic.""" return _parse_cpu_cores(devices) @staticmethod @override - def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices @@ -89,7 +89,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No _PSUTIL_AVAILABLE = RequirementCache("psutil") -def get_cpu_stats() -> Dict[str, float]: +def get_cpu_stats() -> dict[str, float]: if not _PSUTIL_AVAILABLE: raise ModuleNotFoundError( f"Fetching CPU device stats requires `psutil` to be installed. {str(_PSUTIL_AVAILABLE)}" diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py index 6df3bc6b468ee..a00b12a85a8dd 100644 --- a/src/lightning/pytorch/accelerators/cuda.py +++ b/src/lightning/pytorch/accelerators/cuda.py @@ -15,7 +15,7 @@ import os import shutil import subprocess -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import torch from typing_extensions import override @@ -61,7 +61,7 @@ def set_nvidia_flags(local_rank: int) -> None: _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") @override - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Gets stats for the given GPU device. Args: @@ -83,13 +83,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: """Accelerator device parsing logic.""" return _parse_gpu_ids(devices, include_cuda=True) @staticmethod @override - def get_parallel_devices(devices: List[int]) -> List[torch.device]: + def get_parallel_devices(devices: list[int]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" return [torch.device("cuda", i) for i in devices] @@ -114,7 +114,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No ) -def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover +def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: # pragma: no-cover """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. Args: diff --git a/src/lightning/pytorch/accelerators/mps.py b/src/lightning/pytorch/accelerators/mps.py index 6efe6292de624..f7674989cc721 100644 --- a/src/lightning/pytorch/accelerators/mps.py +++ b/src/lightning/pytorch/accelerators/mps.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import torch from typing_extensions import override @@ -43,7 +43,7 @@ def setup_device(self, device: torch.device) -> None: raise MisconfigurationException(f"Device should be MPS, got {device} instead.") @override - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Get M1 (cpu + gpu) stats from ``psutil`` package.""" return get_device_stats() @@ -53,13 +53,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: """Accelerator device parsing logic.""" return _parse_gpu_ids(devices, include_mps=True) @staticmethod @override - def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, str, list[int]]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" parsed_devices = MPSAccelerator.parse_devices(devices) assert parsed_devices is not None @@ -94,7 +94,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No _SWAP_PERCENT = "M1_swap_percent" -def get_device_stats() -> Dict[str, float]: +def get_device_stats() -> dict[str, float]: if not _PSUTIL_AVAILABLE: raise ModuleNotFoundError( f"Fetching MPS device stats requires `psutil` to be installed. {str(_PSUTIL_AVAILABLE)}" diff --git a/src/lightning/pytorch/accelerators/xla.py b/src/lightning/pytorch/accelerators/xla.py index 01ef7223efcef..10726b505448c 100644 --- a/src/lightning/pytorch/accelerators/xla.py +++ b/src/lightning/pytorch/accelerators/xla.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any from typing_extensions import override @@ -29,7 +29,7 @@ class XLAAccelerator(Accelerator, FabricXLAAccelerator): """ @override - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Gets stats for the given XLA device. Args: diff --git a/src/lightning/pytorch/callbacks/callback.py b/src/lightning/pytorch/callbacks/callback.py index 9311d49dcd804..3bfb609465a83 100644 --- a/src/lightning/pytorch/callbacks/callback.py +++ b/src/lightning/pytorch/callbacks/callback.py @@ -13,7 +13,7 @@ # limitations under the License. r"""Base class used to build new callbacks.""" -from typing import Any, Dict, Type +from typing import Any from torch import Tensor from torch.optim import Optimizer @@ -41,7 +41,7 @@ def state_key(self) -> str: return self.__class__.__qualname__ @property - def _legacy_state_key(self) -> Type["Callback"]: + def _legacy_state_key(self) -> type["Callback"]: """State key for checkpoints saved prior to version 1.5.0.""" return type(self) @@ -229,7 +229,7 @@ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: """Called when any trainer execution is interrupted by an exception.""" - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Called when saving a checkpoint, implement to generate callback's ``state_dict``. Returns: @@ -238,7 +238,7 @@ def state_dict(self) -> Dict[str, Any]: """ return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``. Args: @@ -248,7 +248,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: pass def on_save_checkpoint( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any] ) -> None: r"""Called when saving a checkpoint to give you a chance to store anything else you might want to save. @@ -260,7 +260,7 @@ def on_save_checkpoint( """ def on_load_checkpoint( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any] ) -> None: r"""Called when loading a model checkpoint, use to reload state. diff --git a/src/lightning/pytorch/callbacks/device_stats_monitor.py b/src/lightning/pytorch/callbacks/device_stats_monitor.py index 64ea47d2897ed..6279dd13be4af 100644 --- a/src/lightning/pytorch/callbacks/device_stats_monitor.py +++ b/src/lightning/pytorch/callbacks/device_stats_monitor.py @@ -19,7 +19,7 @@ """ -from typing import Any, Dict, Optional +from typing import Any, Optional from typing_extensions import override @@ -158,5 +158,5 @@ def on_test_batch_end( self._get_and_log_device_stats(trainer, "on_test_batch_end") -def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]: +def _prefix_metric_keys(metrics_dict: dict[str, float], prefix: str, separator: str) -> dict[str, float]: return {prefix + separator + k: v for k, v in metrics_dict.items()} diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index d1212fe8cc2e7..78c4215f9ce23 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -20,7 +20,7 @@ """ import logging -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Optional import torch from torch import Tensor @@ -139,7 +139,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s # validation, then we run after validation instead of on train epoch end self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1 - def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool: + def _validate_condition_metric(self, logs: dict[str, Tensor]) -> bool: monitor_val = logs.get(self.monitor) error_msg = ( @@ -163,7 +163,7 @@ def monitor_op(self) -> Callable: return self.mode_dict[self.mode] @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "wait_count": self.wait_count, "stopped_epoch": self.stopped_epoch, @@ -172,7 +172,7 @@ def state_dict(self) -> Dict[str, Any]: } @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.wait_count = state_dict["wait_count"] self.stopped_epoch = state_dict["stopped_epoch"] self.best_score = state_dict["best_score"] @@ -215,7 +215,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None: if reason and self.verbose: self._log_info(trainer, reason, self.log_rank_zero_only) - def _evaluate_stopping_criteria(self, current: Tensor) -> Tuple[bool, Optional[str]]: + def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[str]]: should_stop = False reason = None if self.check_finite and not torch.isfinite(current): diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py index 46a90986c091c..356ab221777ae 100644 --- a/src/lightning/pytorch/callbacks/finetuning.py +++ b/src/lightning/pytorch/callbacks/finetuning.py @@ -19,7 +19,8 @@ """ import logging -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union +from collections.abc import Generator, Iterable +from typing import Any, Callable, Optional, Union import torch from torch.nn import Module, ModuleDict @@ -85,17 +86,17 @@ class BaseFinetuning(Callback): """ def __init__(self) -> None: - self._internal_optimizer_metadata: Dict[int, List[Dict[str, Any]]] = {} + self._internal_optimizer_metadata: dict[int, list[dict[str, Any]]] = {} self._restarting = False @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "internal_optimizer_metadata": self._internal_optimizer_metadata, } @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._restarting = True if "internal_optimizer_metadata" in state_dict: # noqa: SIM401 self._internal_optimizer_metadata = state_dict["internal_optimizer_metadata"] @@ -116,7 +117,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - self._restarting = False @staticmethod - def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]: + def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> list[Module]: """This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules with no children) and parent modules that have parameters directly themselves. @@ -215,7 +216,7 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: BaseFinetuning.freeze_module(mod) @staticmethod - def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List: + def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> list: """This function is used to exclude any parameter which already exists in this optimizer. Args: @@ -285,7 +286,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s ) @staticmethod - def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]: + def _apply_mapping_to_param_groups(param_groups: list[dict[str, Any]], mapping: dict) -> list[dict[str, Any]]: output = [] for g in param_groups: # skip params to save memory @@ -299,7 +300,7 @@ def _store( pl_module: "pl.LightningModule", opt_idx: int, num_param_groups: int, - current_param_groups: List[Dict[str, Any]], + current_param_groups: list[dict[str, Any]], ) -> None: mapping = {p: n for n, p in pl_module.named_parameters()} if opt_idx not in self._internal_optimizer_metadata: @@ -387,14 +388,14 @@ def __init__( self.previous_backbone_lr: Optional[float] = None @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "internal_optimizer_metadata": self._internal_optimizer_metadata, "previous_backbone_lr": self.previous_backbone_lr, } @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.previous_backbone_lr = state_dict["previous_backbone_lr"] super().load_state_dict(state_dict) diff --git a/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py b/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py index 20b1df29d18f5..50ddc10f0f661 100644 --- a/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py +++ b/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py @@ -20,7 +20,7 @@ """ -from typing import Any, Dict +from typing import Any from typing_extensions import override @@ -64,7 +64,7 @@ class GradientAccumulationScheduler(Callback): """ - def __init__(self, scheduling: Dict[int, int]): + def __init__(self, scheduling: dict[int, int]): super().__init__() if not scheduling: # empty dict error diff --git a/src/lightning/pytorch/callbacks/lr_monitor.py b/src/lightning/pytorch/callbacks/lr_monitor.py index 6a94c7ece70a3..ca2b4a866ee50 100644 --- a/src/lightning/pytorch/callbacks/lr_monitor.py +++ b/src/lightning/pytorch/callbacks/lr_monitor.py @@ -22,7 +22,7 @@ import itertools from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Literal, Optional, Set, Tuple, Type +from typing import Any, Literal, Optional import torch from torch.optim.optimizer import Optimizer @@ -104,9 +104,9 @@ def __init__( self.log_momentum = log_momentum self.log_weight_decay = log_weight_decay - self.lrs: Dict[str, List[float]] = {} - self.last_momentum_values: Dict[str, Optional[List[float]]] = {} - self.last_weight_decay_values: Dict[str, Optional[List[float]]] = {} + self.lrs: dict[str, list[float]] = {} + self.last_momentum_values: dict[str, Optional[list[float]]] = {} + self.last_weight_decay_values: dict[str, Optional[list[float]]] = {} @override def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: @@ -141,7 +141,7 @@ def _check_no_key(key: str) -> bool: ) # Find names for schedulers - names: List[List[str]] = [] + names: list[list[str]] = [] ( sched_hparam_keys, optimizers_with_scheduler, @@ -186,7 +186,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) for logger in trainer.loggers: logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped) - def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]: + def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> dict[str, float]: latest_stat = {} ( @@ -219,7 +219,7 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, floa return latest_stat - def _get_optimizer_stats(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]: + def _get_optimizer_stats(self, optimizer: Optimizer, names: list[str]) -> dict[str, float]: stats = {} param_groups = optimizer.param_groups use_betas = "betas" in optimizer.defaults @@ -236,12 +236,12 @@ def _get_optimizer_stats(self, optimizer: Optimizer, names: List[str]) -> Dict[s return stats - def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]: + def _extract_lr(self, param_group: dict[str, Any], name: str) -> dict[str, Any]: lr = param_group["lr"] self.lrs[name].append(lr) return {name: lr} - def _remap_keys(self, names: List[List[str]], token: str = "/pg1") -> None: + def _remap_keys(self, names: list[list[str]], token: str = "/pg1") -> None: """This function is used the remap the keys if param groups for a given optimizer increased.""" for group_new_names in names: for new_name in group_new_names: @@ -251,7 +251,7 @@ def _remap_keys(self, names: List[List[str]], token: str = "/pg1") -> None: elif new_name not in self.lrs: self.lrs[new_name] = [] - def _extract_momentum(self, param_group: Dict[str, List], name: str, use_betas: bool) -> Dict[str, float]: + def _extract_momentum(self, param_group: dict[str, list], name: str, use_betas: bool) -> dict[str, float]: if not self.log_momentum: return {} @@ -259,7 +259,7 @@ def _extract_momentum(self, param_group: Dict[str, List], name: str, use_betas: self.last_momentum_values[name] = momentum return {name: momentum} - def _extract_weight_decay(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]: + def _extract_weight_decay(self, param_group: dict[str, Any], name: str) -> dict[str, Any]: """Extracts the weight decay statistics from a parameter group.""" if not self.log_weight_decay: return {} @@ -269,14 +269,14 @@ def _extract_weight_decay(self, param_group: Dict[str, Any], name: str) -> Dict[ return {name: weight_decay} def _add_prefix( - self, name: str, optimizer_cls: Type[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int] + self, name: str, optimizer_cls: type[Optimizer], seen_optimizer_types: defaultdict[type[Optimizer], int] ) -> str: if optimizer_cls not in seen_optimizer_types: return name count = seen_optimizer_types[optimizer_cls] return name + f"-{count - 1}" if count > 1 else name - def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: int, use_names: bool = True) -> str: + def _add_suffix(self, name: str, param_groups: list[dict], param_group_index: int, use_names: bool = True) -> str: if len(param_groups) > 1: if not use_names: return f"{name}/pg{param_group_index + 1}" @@ -287,7 +287,7 @@ def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: in return f"{name}/{pg_name}" if pg_name else name return name - def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]: + def _duplicate_param_group_names(self, param_groups: list[dict]) -> set[str]: names = [pg.get("name", f"pg{i}") for i, pg in enumerate(param_groups, start=1)] unique = set(names) if len(names) == len(unique): @@ -296,13 +296,13 @@ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]: def _find_names_from_schedulers( self, - lr_scheduler_configs: List[LRSchedulerConfig], - ) -> Tuple[List[List[str]], List[Optimizer], DefaultDict[Type[Optimizer], int]]: + lr_scheduler_configs: list[LRSchedulerConfig], + ) -> tuple[list[list[str]], list[Optimizer], defaultdict[type[Optimizer], int]]: # Create unique names in the case we have multiple of the same learning # rate scheduler + multiple parameter groups names = [] - seen_optimizers: List[Optimizer] = [] - seen_optimizer_types: DefaultDict[Type[Optimizer], int] = defaultdict(int) + seen_optimizers: list[Optimizer] = [] + seen_optimizer_types: defaultdict[type[Optimizer], int] = defaultdict(int) for config in lr_scheduler_configs: sch = config.scheduler name = config.name if config.name is not None else "lr-" + sch.optimizer.__class__.__name__ @@ -316,10 +316,10 @@ def _find_names_from_schedulers( def _find_names_from_optimizers( self, - optimizers: List[Any], - seen_optimizers: List[Optimizer], - seen_optimizer_types: DefaultDict[Type[Optimizer], int], - ) -> Tuple[List[List[str]], List[Optimizer]]: + optimizers: list[Any], + seen_optimizers: list[Optimizer], + seen_optimizer_types: defaultdict[type[Optimizer], int], + ) -> tuple[list[list[str]], list[Optimizer]]: names = [] optimizers_without_scheduler = [] @@ -342,10 +342,10 @@ def _check_duplicates_and_update_name( self, optimizer: Optimizer, name: str, - seen_optimizers: List[Optimizer], - seen_optimizer_types: DefaultDict[Type[Optimizer], int], + seen_optimizers: list[Optimizer], + seen_optimizer_types: defaultdict[type[Optimizer], int], lr_scheduler_config: Optional[LRSchedulerConfig], - ) -> List[str]: + ) -> list[str]: seen_optimizers.append(optimizer) optimizer_cls = type(optimizer) if lr_scheduler_config is None or lr_scheduler_config.name is None: diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 9587da0f4600b..85bfb65c0ea6e 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -27,7 +27,7 @@ from copy import deepcopy from datetime import timedelta from pathlib import Path -from typing import Any, Dict, Literal, Optional, Set, Union +from typing import Any, Literal, Optional, Union from weakref import proxy import torch @@ -241,7 +241,7 @@ def __init__( self._last_global_step_saved = 0 # no need to save when no steps were taken self._last_time_checked: Optional[float] = None self.current_score: Optional[Tensor] = None - self.best_k_models: Dict[str, Tensor] = {} + self.best_k_models: dict[str, Tensor] = {} self.kth_best_model_path = "" self.best_model_score: Optional[Tensor] = None self.best_model_path = "" @@ -335,7 +335,7 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul self._save_last_checkpoint(trainer, monitor_candidates) @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "monitor": self.monitor, "best_model_score": self.best_model_score, @@ -349,7 +349,7 @@ def state_dict(self) -> Dict[str, Any]: } @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: dirpath_from_ckpt = state_dict.get("dirpath", self.dirpath) if self.dirpath == dirpath_from_ckpt: @@ -367,7 +367,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.best_model_path = state_dict["best_model_path"] - def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: + def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor]) -> None: if self.save_top_k == 0: return @@ -533,7 +533,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] = def _format_checkpoint_name( self, filename: Optional[str], - metrics: Dict[str, Tensor], + metrics: dict[str, Tensor], prefix: str = "", auto_insert_metric_name: bool = True, ) -> str: @@ -567,7 +567,7 @@ def _format_checkpoint_name( return filename def format_checkpoint_name( - self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None + self, metrics: dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None ) -> str: """Generate a filename according to the defined template. @@ -637,7 +637,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH: return ckpt_path - def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]: + def _find_last_checkpoints(self, trainer: "pl.Trainer") -> set[str]: # find all checkpoints in the folder ckpt_path = self.__resolve_ckpt_dir(trainer) last_pattern = rf"^{self.CHECKPOINT_NAME_LAST}(-(\d+))?" @@ -654,7 +654,7 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") def _get_metric_interpolated_filepath_name( - self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None + self, monitor_candidates: dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None ) -> str: filepath = self.format_checkpoint_name(monitor_candidates) @@ -666,7 +666,7 @@ def _get_metric_interpolated_filepath_name( return filepath - def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]: + def _monitor_candidates(self, trainer: "pl.Trainer") -> dict[str, Tensor]: monitor_candidates = deepcopy(trainer.callback_metrics) # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor # or does not exist we overwrite it as it's likely an error @@ -676,7 +676,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]: monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step) return monitor_candidates - def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: + def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor]) -> None: if not self.save_last: return @@ -697,7 +697,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[ if previous and self._should_remove_checkpoint(trainer, previous, filepath): self._remove_checkpoint(trainer, previous) - def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: + def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor]) -> None: assert self.monitor current = monitor_candidates.get(self.monitor) if self.check_monitor_top_k(trainer, current): @@ -708,7 +708,7 @@ def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Di step = monitor_candidates["step"] rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}") - def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: + def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor]) -> None: filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, self.best_model_path) # set the best model path before saving because it will be part of the state. previous, self.best_model_path = self.best_model_path, filepath @@ -718,7 +718,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate self._remove_checkpoint(trainer, previous) def _update_best_and_save( - self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor] + self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor] ) -> None: k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k diff --git a/src/lightning/pytorch/callbacks/model_summary.py b/src/lightning/pytorch/callbacks/model_summary.py index 89c31b2cc65e8..03f50d65bf1e9 100644 --- a/src/lightning/pytorch/callbacks/model_summary.py +++ b/src/lightning/pytorch/callbacks/model_summary.py @@ -23,7 +23,7 @@ """ import logging -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Union from typing_extensions import override @@ -54,7 +54,7 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1, **summarize_kwargs: Any) -> None: self._max_depth: int = max_depth - self._summarize_kwargs: Dict[str, Any] = summarize_kwargs + self._summarize_kwargs: dict[str, Any] = summarize_kwargs @override def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -87,11 +87,11 @@ def _summary(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Un @staticmethod def summarize( - summary_data: List[Tuple[str, List[str]]], + summary_data: list[tuple[str, list[str]]], total_parameters: int, trainable_parameters: int, model_size: float, - total_training_modes: Dict[str, int], + total_training_modes: dict[str, int], **summarize_kwargs: Any, ) -> None: summary_table = _format_summary_table( diff --git a/src/lightning/pytorch/callbacks/prediction_writer.py b/src/lightning/pytorch/callbacks/prediction_writer.py index 7f782fb81c091..ce6342c7aa88d 100644 --- a/src/lightning/pytorch/callbacks/prediction_writer.py +++ b/src/lightning/pytorch/callbacks/prediction_writer.py @@ -18,7 +18,8 @@ Aids in saving predictions """ -from typing import Any, Literal, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Literal, Optional from typing_extensions import override diff --git a/src/lightning/pytorch/callbacks/progress/progress_bar.py b/src/lightning/pytorch/callbacks/progress/progress_bar.py index 785bf65af4361..7cf6993b4414b 100644 --- a/src/lightning/pytorch/callbacks/progress/progress_bar.py +++ b/src/lightning/pytorch/callbacks/progress/progress_bar.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from typing_extensions import override @@ -176,7 +176,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s def get_metrics( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" - ) -> Dict[str, Union[int, str, float, Dict[str, float]]]: + ) -> dict[str, Union[int, str, float, dict[str, float]]]: r"""Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. Implement this to override the items displayed in the progress bar. @@ -207,7 +207,7 @@ def get_metrics(self, trainer, model): return {**standard_metrics, **pbar_metrics} -def get_standard_metrics(trainer: "pl.Trainer") -> Dict[str, Union[int, str]]: +def get_standard_metrics(trainer: "pl.Trainer") -> dict[str, Union[int, str]]: r"""Returns the standard metrics displayed in the progress bar. Currently, it only includes the version of the experiment when using a logger. @@ -219,7 +219,7 @@ def get_standard_metrics(trainer: "pl.Trainer") -> Dict[str, Union[int, str]]: Dictionary with the standard metrics to be displayed in the progress bar. """ - items_dict: Dict[str, Union[int, str]] = {} + items_dict: dict[str, Union[int, str]] = {} if trainer.loggers: from lightning.pytorch.loggers.utilities import _version diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 497e96e11b9c4..0a51d99ccb676 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +from collections.abc import Generator from dataclasses import dataclass from datetime import timedelta -from typing import Any, Dict, Generator, Optional, Union, cast +from typing import Any, Optional, Union, cast from lightning_utilities.core.imports import RequirementCache from typing_extensions import override @@ -146,15 +147,15 @@ def __init__( metrics_format: str, ): self._trainer = trainer - self._tasks: Dict[Union[int, TaskID], Any] = {} + self._tasks: dict[Union[int, TaskID], Any] = {} self._current_task_id = 0 - self._metrics: Dict[Union[str, Style], Any] = {} + self._metrics: dict[Union[str, Style], Any] = {} self._style = style self._text_delimiter = text_delimiter self._metrics_format = metrics_format super().__init__() - def update(self, metrics: Dict[Any, Any]) -> None: + def update(self, metrics: dict[Any, Any]) -> None: # Called when metrics are ready to be rendered. # This is to prevent render from causing deadlock issues by requesting metrics # in separate threads. @@ -206,14 +207,14 @@ class RichProgressBarTheme: """ - description: Union[str, "Style"] = "white" + description: Union[str, "Style"] = "" progress_bar: Union[str, "Style"] = "#6206E0" progress_bar_finished: Union[str, "Style"] = "#6206E0" progress_bar_pulse: Union[str, "Style"] = "#6206E0" - batch_progress: Union[str, "Style"] = "white" - time: Union[str, "Style"] = "grey54" - processing_speed: Union[str, "Style"] = "grey70" - metrics: Union[str, "Style"] = "white" + batch_progress: Union[str, "Style"] = "" + time: Union[str, "Style"] = "dim" + processing_speed: Union[str, "Style"] = "dim underline" + metrics: Union[str, "Style"] = "italic" metrics_text_delimiter: str = " " metrics_format: str = ".3f" @@ -257,7 +258,7 @@ def __init__( refresh_rate: int = 1, leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(), - console_kwargs: Optional[Dict[str, Any]] = None, + console_kwargs: Optional[dict[str, Any]] = None, ) -> None: if not _RICH_AVAILABLE: raise ModuleNotFoundError( @@ -280,7 +281,6 @@ def __init__( self._metric_component: Optional[MetricsTextColumn] = None self._progress_stopped: bool = False self.theme = theme - self._update_for_light_colab_theme() @property def refresh_rate(self) -> float: @@ -318,13 +318,6 @@ def test_progress_bar(self) -> "Task": assert self.test_progress_bar_id is not None return self.progress.tasks[self.test_progress_bar_id] - def _update_for_light_colab_theme(self) -> None: - if _detect_light_colab_theme(): - attributes = ["description", "batch_progress", "metrics"] - for attr in attributes: - if getattr(self.theme, attr) == "white": - setattr(self.theme, attr, "black") - @override def disable(self) -> None: self._enabled = False @@ -449,7 +442,7 @@ def on_validation_batch_start( def _add_task(self, total_batches: Union[int, float], description: str, visible: bool = True) -> "TaskID": assert self.progress is not None return self.progress.add_task( - f"[{self.theme.description}]{description}", + f"[{self.theme.description}]{description}" if self.theme.description else description, total=total_batches, visible=visible, ) @@ -650,26 +643,9 @@ def configure_columns(self, trainer: "pl.Trainer") -> list: ProcessingSpeedColumn(style=self.theme.processing_speed), ] - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: state = self.__dict__.copy() # both the console and progress object can hold thread lock objects that are not pickleable state["progress"] = None state["_console"] = None return state - - -def _detect_light_colab_theme() -> bool: - """Detect if it's light theme in Colab.""" - try: - import get_ipython - except (NameError, ModuleNotFoundError): - return False - ipython = get_ipython() - if "google.colab" in str(ipython.__class__): - try: - from google.colab import output - - return output.eval_js('document.documentElement.matches("[theme=light]")') - except ModuleNotFoundError: - return False - return False diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index cf9cd71614674..4ef260f00006d 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -15,7 +15,7 @@ import math import os import sys -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from typing_extensions import override @@ -115,7 +115,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0, leave: bool self._predict_progress_bar: Optional[_tqdm] = None self._leave = leave - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: # can't pickle the tqdm objects return {k: v if not isinstance(v, _tqdm) else None for k, v in vars(self).items()} diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py index e83a9de06375c..1517ef6920b0d 100644 --- a/src/lightning/pytorch/callbacks/pruning.py +++ b/src/lightning/pytorch/callbacks/pruning.py @@ -18,9 +18,10 @@ import inspect import logging +from collections.abc import Sequence from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Union import torch.nn.utils.prune as pytorch_prune from lightning_utilities.core.apply_func import apply_to_collection @@ -49,14 +50,14 @@ "random_unstructured": pytorch_prune.RandomUnstructured, } -_PARAM_TUPLE = Tuple[nn.Module, str] +_PARAM_TUPLE = tuple[nn.Module, str] _PARAM_LIST = Sequence[_PARAM_TUPLE] _MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict) class _LayerRef(TypedDict): data: nn.Module - names: List[Tuple[int, str]] + names: list[tuple[int, str]] class ModelPruning(Callback): @@ -66,7 +67,7 @@ def __init__( self, pruning_fn: Union[Callable, str], parameters_to_prune: _PARAM_LIST = (), - parameter_names: Optional[List[str]] = None, + parameter_names: Optional[list[str]] = None, use_global_unstructured: bool = True, amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5, apply_pruning: Union[bool, Callable[[int], bool]] = True, @@ -165,8 +166,8 @@ def __init__( self._resample_parameters = resample_parameters self._prune_on_train_epoch_end = prune_on_train_epoch_end self._parameter_names = parameter_names or self.PARAMETER_NAMES - self._global_kwargs: Dict[str, Any] = {} - self._original_layers: Optional[Dict[int, _LayerRef]] = None + self._global_kwargs: dict[str, Any] = {} + self._original_layers: Optional[dict[int, _LayerRef]] = None self._pruning_method_name: Optional[str] = None for name in self._parameter_names: @@ -310,7 +311,7 @@ def _apply_local_pruning(self, amount: float) -> None: for module, name in self._parameters_to_prune: self.pruning_fn(module, name=name, amount=amount) # type: ignore[call-arg] - def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]: + def _resolve_global_kwargs(self, amount: float) -> dict[str, Any]: self._global_kwargs["amount"] = amount params = set(inspect.signature(self.pruning_fn).parameters) params.discard("self") @@ -322,7 +323,7 @@ def _apply_global_pruning(self, amount: float) -> None: ) @staticmethod - def _get_pruned_stats(module: nn.Module, name: str) -> Tuple[int, int]: + def _get_pruned_stats(module: nn.Module, name: str) -> tuple[int, int]: attr = f"{name}_mask" if not hasattr(module, attr): return 0, 1 @@ -345,7 +346,7 @@ def apply_pruning(self, amount: Union[int, float]) -> None: @rank_zero_only def _log_sparsity_stats( - self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0 + self, prev: list[tuple[int, int]], curr: list[tuple[int, int]], amount: Union[int, float] = 0 ) -> None: total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters()) prev_total_zeros = sum(zeros for zeros, _ in prev) @@ -414,7 +415,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> Non rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint") self.make_pruning_permanent(pl_module) - def _make_pruning_permanent_on_state_dict(self, pl_module: LightningModule) -> Dict[str, Any]: + def _make_pruning_permanent_on_state_dict(self, pl_module: LightningModule) -> dict[str, Any]: state_dict = pl_module.state_dict() # find the mask and the original weights. @@ -432,7 +433,7 @@ def move_to_cpu(tensor: Tensor) -> Tensor: return apply_to_collection(state_dict, Tensor, move_to_cpu) @override - def on_save_checkpoint(self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any]) -> None: + def on_save_checkpoint(self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: dict[str, Any]) -> None: if self._make_pruning_permanent: rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint") # manually prune the weights so training can keep going with the same buffers diff --git a/src/lightning/pytorch/callbacks/rich_model_summary.py b/src/lightning/pytorch/callbacks/rich_model_summary.py index c6c429b4bd2f5..e4027f0dedcb1 100644 --- a/src/lightning/pytorch/callbacks/rich_model_summary.py +++ b/src/lightning/pytorch/callbacks/rich_model_summary.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Tuple +from typing import Any from typing_extensions import override @@ -67,11 +67,11 @@ def __init__(self, max_depth: int = 1, **summarize_kwargs: Any) -> None: @staticmethod @override def summarize( - summary_data: List[Tuple[str, List[str]]], + summary_data: list[tuple[str, list[str]]], total_parameters: int, trainable_parameters: int, model_size: float, - total_training_modes: Dict[str, int], + total_training_modes: dict[str, int], **summarize_kwargs: Any, ) -> None: from rich import get_console diff --git a/src/lightning/pytorch/callbacks/spike.py b/src/lightning/pytorch/callbacks/spike.py index 725d6f64333a6..b006acd44dcdb 100644 --- a/src/lightning/pytorch/callbacks/spike.py +++ b/src/lightning/pytorch/callbacks/spike.py @@ -1,5 +1,6 @@ import os -from typing import Any, Mapping, Union +from collections.abc import Mapping +from typing import Any, Union import torch diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py index 737084ced426d..5643a038e00c1 100644 --- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py +++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py @@ -17,7 +17,7 @@ """ from copy import deepcopy -from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast +from typing import Any, Callable, Literal, Optional, Union, cast import torch from torch import Tensor, nn @@ -39,7 +39,7 @@ class StochasticWeightAveraging(Callback): def __init__( self, - swa_lrs: Union[float, List[float]], + swa_lrs: Union[float, list[float]], swa_epoch_start: Union[int, float] = 0.8, annealing_epochs: int = 10, annealing_strategy: Literal["cos", "linear"] = "cos", @@ -126,10 +126,10 @@ def __init__( self._average_model: Optional[pl.LightningModule] = None self._initialized = False self._swa_scheduler: Optional[LRScheduler] = None - self._scheduler_state: Optional[Dict] = None + self._scheduler_state: Optional[dict] = None self._init_n_averaged = 0 self._latest_update_epoch = -1 - self.momenta: Dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {} + self.momenta: dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {} self._max_epochs: int @property @@ -331,7 +331,7 @@ def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averag return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(), "latest_update_epoch": self._latest_update_epoch, @@ -340,7 +340,7 @@ def state_dict(self) -> Dict[str, Any]: } @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._init_n_averaged = state_dict["n_averaged"] self._latest_update_epoch = state_dict["latest_update_epoch"] self._scheduler_state = state_dict["scheduler_state"] diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index a2d73d83184b1..a49610a912e57 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import time -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from typing_extensions import override @@ -84,9 +84,9 @@ def __init__( self.batch_size_fn = batch_size_fn self.length_fn = length_fn self.available_flops: Optional[int] = None - self._throughputs: Dict[RunningStage, Throughput] = {} - self._t0s: Dict[RunningStage, float] = {} - self._lengths: Dict[RunningStage, int] = {} + self._throughputs: dict[RunningStage, Throughput] = {} + self._t0s: dict[RunningStage, float] = {} + self._lengths: dict[RunningStage, int] = {} @override def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None: diff --git a/src/lightning/pytorch/callbacks/timer.py b/src/lightning/pytorch/callbacks/timer.py index e1bed4adb9889..b6b74d280427c 100644 --- a/src/lightning/pytorch/callbacks/timer.py +++ b/src/lightning/pytorch/callbacks/timer.py @@ -20,7 +20,7 @@ import re import time from datetime import timedelta -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from typing_extensions import override @@ -83,7 +83,7 @@ class Timer(Callback): def __init__( self, - duration: Optional[Union[str, timedelta, Dict[str, int]]] = None, + duration: Optional[Union[str, timedelta, dict[str, int]]] = None, interval: str = Interval.step, verbose: bool = True, ) -> None: @@ -111,8 +111,8 @@ def __init__( self._duration = duration.total_seconds() if duration is not None else None self._interval = interval self._verbose = verbose - self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} - self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} + self._start_time: dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} + self._end_time: dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} self._offset = 0 def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: @@ -187,11 +187,11 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) - self._check_time_remaining(trainer) @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage}} @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: time_elapsed = state_dict.get("time_elapsed", {}) self._offset = time_elapsed.get(RunningStage.TRAINING.value, 0) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 26af335f7be93..5e70a9bdd8b9e 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -14,9 +14,10 @@ import inspect import os import sys +from collections.abc import Iterable from functools import partial, update_wrapper from types import MethodType -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Optional, TypeVar, Union import torch import yaml @@ -36,6 +37,18 @@ _JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.7") + +def patch_jsonargparse_python_3_12_8() -> None: + if sys.version_info < (3, 12, 8): + return + + def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]: + namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore + return namespace, args + + setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch) + + if _JSONARGPARSE_SIGNATURES_AVAILABLE: import docstring_parser from jsonargparse import ( @@ -47,6 +60,8 @@ set_config_read_mode, ) + patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641 + register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483 set_config_read_mode(fsspec_enabled=True) else: @@ -65,11 +80,11 @@ def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any # LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch: LRSchedulerTypeTuple = (LRScheduler, ReduceLROnPlateau) LRSchedulerTypeUnion = Union[LRScheduler, ReduceLROnPlateau] -LRSchedulerType = Union[Type[LRScheduler], Type[ReduceLROnPlateau]] +LRSchedulerType = Union[type[LRScheduler], type[ReduceLROnPlateau]] # Type aliases intended for convenience of CLI developers -ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]] +ArgsType = Optional[Union[list[str], dict[str, Any], Namespace]] OptimizerCallable = Callable[[Iterable], Optimizer] LRSchedulerCallable = Callable[[Optimizer], Union[LRScheduler, ReduceLROnPlateau]] @@ -99,24 +114,24 @@ def __init__( if not _JSONARGPARSE_SIGNATURES_AVAILABLE: raise ModuleNotFoundError(f"{_JSONARGPARSE_SIGNATURES_AVAILABLE}") super().__init__(*args, description=description, env_prefix=env_prefix, default_env=default_env, **kwargs) - self.callback_keys: List[str] = [] + self.callback_keys: list[str] = [] # separate optimizers and lr schedulers to know which were added - self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} - self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} + self._optimizers: dict[str, tuple[Union[type, tuple[type, ...]], str]] = {} + self._lr_schedulers: dict[str, tuple[Union[type, tuple[type, ...]], str]] = {} def add_lightning_class_args( self, lightning_class: Union[ Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]], - Type[Trainer], - Type[LightningModule], - Type[LightningDataModule], - Type[Callback], + type[Trainer], + type[LightningModule], + type[LightningDataModule], + type[Callback], ], nested_key: str, subclass_mode: bool = False, required: bool = True, - ) -> List[str]: + ) -> list[str]: """Adds arguments from a lightning class to a nested key of the parser. Args: @@ -153,7 +168,7 @@ def add_lightning_class_args( def add_optimizer_args( self, - optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]] = (Optimizer,), + optimizer_class: Union[type[Optimizer], tuple[type[Optimizer], ...]] = (Optimizer,), nested_key: str = "optimizer", link_to: str = "AUTOMATIC", ) -> None: @@ -169,7 +184,7 @@ def add_optimizer_args( assert all(issubclass(o, Optimizer) for o in optimizer_class) else: assert issubclass(optimizer_class, Optimizer) - kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} + kwargs: dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} if isinstance(optimizer_class, tuple): self.add_subclass_arguments(optimizer_class, nested_key, **kwargs) else: @@ -178,7 +193,7 @@ def add_optimizer_args( def add_lr_scheduler_args( self, - lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple, + lr_scheduler_class: Union[LRSchedulerType, tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple, nested_key: str = "lr_scheduler", link_to: str = "AUTOMATIC", ) -> None: @@ -195,7 +210,7 @@ def add_lr_scheduler_args( assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class) else: assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) - kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} + kwargs: dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} if isinstance(lr_scheduler_class, tuple): self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) else: @@ -305,14 +320,14 @@ class LightningCLI: def __init__( self, - model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None, - datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None, - save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback, - save_config_kwargs: Optional[Dict[str, Any]] = None, - trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer, - trainer_defaults: Optional[Dict[str, Any]] = None, + model_class: Optional[Union[type[LightningModule], Callable[..., LightningModule]]] = None, + datamodule_class: Optional[Union[type[LightningDataModule], Callable[..., LightningDataModule]]] = None, + save_config_callback: Optional[type[SaveConfigCallback]] = SaveConfigCallback, + save_config_kwargs: Optional[dict[str, Any]] = None, + trainer_class: Union[type[Trainer], Callable[..., Trainer]] = Trainer, + trainer_defaults: Optional[dict[str, Any]] = None, seed_everything_default: Union[bool, int] = True, - parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None, + parser_kwargs: Optional[Union[dict[str, Any], dict[str, dict[str, Any]]]] = None, subclass_mode_model: bool = False, subclass_mode_data: bool = False, args: ArgsType = None, @@ -389,11 +404,12 @@ def __init__( self._add_instantiators() self.before_instantiate_classes() self.instantiate_classes() + self.after_instantiate_classes() if self.subcommand is not None: self._run_subcommand(self.subcommand) - def _setup_parser_kwargs(self, parser_kwargs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _setup_parser_kwargs(self, parser_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: subcommand_names = self.subcommands().keys() main_kwargs = {k: v for k, v in parser_kwargs.items() if k not in subcommand_names} subparser_kwargs = {k: v for k, v in parser_kwargs.items() if k in subcommand_names} @@ -409,12 +425,12 @@ def init_parser(self, **kwargs: Any) -> LightningArgumentParser: return parser def setup_parser( - self, add_subcommands: bool, main_kwargs: Dict[str, Any], subparser_kwargs: Dict[str, Any] + self, add_subcommands: bool, main_kwargs: dict[str, Any], subparser_kwargs: dict[str, Any] ) -> None: """Initialize and setup the parser, subcommands, and arguments.""" self.parser = self.init_parser(**main_kwargs) if add_subcommands: - self._subcommand_method_arguments: Dict[str, List[str]] = {} + self._subcommand_method_arguments: dict[str, list[str]] = {} self._add_subcommands(self.parser, **subparser_kwargs) else: self._add_arguments(self.parser) @@ -469,7 +485,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: """ @staticmethod - def subcommands() -> Dict[str, Set[str]]: + def subcommands() -> dict[str, set[str]]: """Defines the list of available subcommands and the arguments to skip.""" return { "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"}, @@ -480,7 +496,7 @@ def subcommands() -> Dict[str, Set[str]]: def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None: """Adds subcommands to the input parser.""" - self._subcommand_parsers: Dict[str, LightningArgumentParser] = {} + self._subcommand_parsers: dict[str, LightningArgumentParser] = {} parser_subcommands = parser.add_subcommands() # the user might have passed a builder function trainer_class = ( @@ -497,11 +513,11 @@ def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> No self._subcommand_parsers[subcommand] = subcommand_parser parser_subcommands.add_subcommand(subcommand, subcommand_parser, help=description) - def _prepare_subcommand_parser(self, klass: Type, subcommand: str, **kwargs: Any) -> LightningArgumentParser: + def _prepare_subcommand_parser(self, klass: type, subcommand: str, **kwargs: Any) -> LightningArgumentParser: parser = self.init_parser(**kwargs) self._add_arguments(parser) # subcommand arguments - skip: Set[Union[str, int]] = set(self.subcommands()[subcommand]) + skip: set[Union[str, int]] = set(self.subcommands()[subcommand]) added = parser.add_method_arguments(klass, subcommand, skip=skip) # need to save which arguments were added to pass them to the method later self._subcommand_method_arguments[subcommand] = added @@ -560,6 +576,9 @@ def instantiate_classes(self) -> None: self._add_configure_optimizers_method_to_model(self.subcommand) self.trainer = self.instantiate_trainer() + def after_instantiate_classes(self) -> None: + """Implement to run some code after instantiating the classes.""" + def instantiate_trainer(self, **kwargs: Any) -> Trainer: """Instantiates the trainer. @@ -571,7 +590,7 @@ def instantiate_trainer(self, **kwargs: Any) -> Trainer: trainer_config = {**self._get(self.config_init, "trainer", default={}), **kwargs} return self._instantiate_trainer(trainer_config, extra_callbacks) - def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer: + def _instantiate_trainer(self, config: dict[str, Any], callbacks: list[Callback]) -> Trainer: key = "callbacks" if key in config: if config[key] is None: @@ -632,8 +651,8 @@ def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) - parser = self._parser(subcommand) def get_automatic( - class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] - ) -> List[str]: + class_type: Union[type, tuple[type, ...]], register: dict[str, tuple[Union[type, tuple[type, ...]], str]] + ) -> list[str]: automatic = [] for key, (base_class, link_to) in register.items(): if not isinstance(base_class, tuple): @@ -704,7 +723,7 @@ def _run_subcommand(self, subcommand: str) -> None: if callable(after_fn): after_fn() - def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]: + def _prepare_subcommand_kwargs(self, subcommand: str) -> dict[str, Any]: """Prepares the keyword arguments to pass to the subcommand to run.""" fn_kwargs = { k: v for k, v in self.config_init[subcommand].items() if k in self._subcommand_method_arguments[subcommand] @@ -730,26 +749,26 @@ def _set_seed(self) -> None: self.config["seed_everything"] = config_seed -def _class_path_from_class(class_type: Type) -> str: +def _class_path_from_class(class_type: type) -> str: return class_type.__module__ + "." + class_type.__name__ def _global_add_class_path( - class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None -) -> Dict[str, Any]: + class_type: type, init_args: Optional[Union[Namespace, dict[str, Any]]] = None +) -> dict[str, Any]: if isinstance(init_args, Namespace): init_args = init_args.as_dict() return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}} -def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]: - def add_class_path(init_args: Namespace) -> Dict[str, Any]: +def _add_class_path_generator(class_type: type) -> Callable[[Namespace], dict[str, Any]]: + def add_class_path(init_args: Namespace) -> dict[str, Any]: return _global_add_class_path(class_type, init_args) return add_class_path -def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: +def instantiate_class(args: Union[Any, tuple[Any, ...]], init: dict[str, Any]) -> Any: """Instantiates a class with the given args and init. Args: @@ -790,7 +809,7 @@ def __init__(self, cli: LightningCLI, key: str) -> None: self.cli = cli self.key = key - def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType: + def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType: hparams = self.cli.config_dump.get(self.key, {}) if "class_path" in hparams: # To make hparams backwards compatible, and so that it is the same irrespective of subclass_mode, the @@ -808,7 +827,7 @@ def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> M return class_type(*args, **kwargs) -def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType: +def instantiate_module(class_type: type[ModuleType], config: dict[str, Any]) -> ModuleType: parser = ArgumentParser(exit_on_error=False) if "_class_path" in config: parser.add_subclass_arguments(class_type, "module", fail_untyped=False) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 6cb8f79f09284..ff84c2fd8b199 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -14,7 +14,9 @@ """LightningDataModule for loading DataLoaders with ease.""" import inspect -from typing import IO, Any, Dict, Iterable, Optional, Union, cast +import os +from collections.abc import Iterable, Sized +from typing import IO, Any, Optional, Union, cast from lightning_utilities import apply_to_collection from torch.utils.data import DataLoader, Dataset, IterableDataset @@ -147,7 +149,7 @@ def predict_dataloader() -> EVAL_DATALOADERS: datamodule.predict_dataloader = predict_dataloader # type: ignore[method-assign] return datamodule - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Called when saving a checkpoint, implement to generate and save datamodule state. Returns: @@ -156,7 +158,7 @@ def state_dict(self) -> Dict[str, Any]: """ return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict. Args: @@ -243,3 +245,75 @@ def load_from_checkpoint( **kwargs, ) return cast(Self, loaded) + + def __str__(self) -> str: + """Return a string representation of the datasets that are set up. + + Returns: + A string representation of the datasets that are setup. + + """ + + class dataset_info: + def __init__(self, available: bool, length: str) -> None: + self.available = available + self.length = length + + def retrieve_dataset_info(loader: DataLoader) -> dataset_info: + """Helper function to compute dataset information.""" + dataset = loader.dataset + size: str = str(len(dataset)) if isinstance(dataset, Sized) else "NA" + + return dataset_info(True, size) + + def loader_info( + loader: Union[DataLoader, Iterable[DataLoader]], + ) -> Union[dataset_info, Iterable[dataset_info]]: + """Helper function to compute dataset information.""" + return apply_to_collection(loader, DataLoader, retrieve_dataset_info) + + def extract_loader_info(methods: list[tuple[str, str]]) -> dict: + """Helper function to extract information for each dataloader method.""" + info: dict[str, Union[dataset_info, Iterable[dataset_info]]] = {} + for loader_name, func_name in methods: + loader_method = getattr(self, func_name, None) + + try: + loader = loader_method() # type: ignore + info[loader_name] = loader_info(loader) + except Exception: + info[loader_name] = dataset_info(False, "") + + return info + + def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info]]]) -> str: + """Helper function to format loader information.""" + output = [] + for loader_name, loader_info in info.items(): + # Single dataset + if isinstance(loader_info, dataset_info): + loader_info_formatted = "None" if not loader_info.available else f"size={loader_info.length}" + # Iterable of datasets + else: + loader_info_formatted = " ; ".join( + "None" if not loader_info_i.available else f"{i}. size={loader_info_i.length}" + for i, loader_info_i in enumerate(loader_info, start=1) + ) + + output.append(f"{{{loader_name}: {loader_info_formatted}}}") + + return os.linesep.join(output) + + # Available dataloader methods + datamodule_loader_methods: list[tuple[str, str]] = [ + ("Train dataloader", "train_dataloader"), + ("Validation dataloader", "val_dataloader"), + ("Test dataloader", "test_dataloader"), + ("Predict dataloader", "predict_dataloader"), + ] + + # Retrieve information for each dataloader method + dataloader_info = extract_loader_info(datamodule_loader_methods) + # Format the information + dataloader_str = format_loader_info(dataloader_info) + return dataloader_str diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py index 5495a0262036d..0b0ab14244e38 100644 --- a/src/lightning/pytorch/core/hooks.py +++ b/src/lightning/pytorch/core/hooks.py @@ -13,7 +13,7 @@ # limitations under the License. """Various hooks to be used in the Lightning code.""" -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from torch import Tensor @@ -670,7 +670,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx): class CheckpointHooks: """Hooks to be used with Checkpointing.""" - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: r"""Called by Lightning to restore your model. If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this. @@ -689,7 +689,7 @@ def on_load_checkpoint(self, checkpoint): """ - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: r"""Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save. diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py index 94ece0039d4f4..3a01cd2fe9a7c 100644 --- a/src/lightning/pytorch/core/mixins/hparams_mixin.py +++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py @@ -15,9 +15,10 @@ import inspect import types from argparse import Namespace +from collections.abc import Iterator, MutableMapping, Sequence from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union +from typing import Any, Optional, Union from lightning.fabric.utilities.data import AttributeDict from lightning.pytorch.utilities.parsing import save_hyperparameters @@ -41,7 +42,7 @@ def _given_hyperparameters_context(hparams: dict, instantiator: str) -> Iterator class HyperparametersMixin: - __jit_unused_properties__: List[str] = ["hparams", "hparams_initial"] + __jit_unused_properties__: list[str] = ["hparams", "hparams_initial"] def __init__(self) -> None: super().__init__() diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 782fc40d928ef..f1d1da924eac4 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -16,6 +16,7 @@ import logging import numbers import weakref +from collections.abc import Generator, Mapping, Sequence from contextlib import contextmanager from io import BytesIO from pathlib import Path @@ -24,14 +25,8 @@ TYPE_CHECKING, Any, Callable, - Dict, - Generator, - List, Literal, - Mapping, Optional, - Sequence, - Tuple, Union, cast, overload, @@ -86,7 +81,7 @@ log = logging.getLogger(__name__) MODULE_OPTIMIZERS = Union[ - Optimizer, LightningOptimizer, _FabricOptimizer, List[Optimizer], List[LightningOptimizer], List[_FabricOptimizer] + Optimizer, LightningOptimizer, _FabricOptimizer, list[Optimizer], list[LightningOptimizer], list[_FabricOptimizer] ] @@ -100,7 +95,7 @@ class LightningModule( ): # Below is for property support of JIT # since none of these are important when using JIT, we are going to ignore them. - __jit_unused_properties__: List[str] = ( + __jit_unused_properties__: list[str] = ( [ "example_input_array", "on_gpu", @@ -132,19 +127,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._trainer: Optional[pl.Trainer] = None # attributes that can be set by user - self._example_input_array: Optional[Union[Tensor, Tuple, Dict]] = None + self._example_input_array: Optional[Union[Tensor, tuple, dict]] = None self._automatic_optimization: bool = True self._strict_loading: Optional[bool] = None # attributes used internally self._current_fx_name: Optional[str] = None - self._param_requires_grad_state: Dict[str, bool] = {} - self._metric_attributes: Optional[Dict[int, str]] = None - self._compiler_ctx: Optional[Dict[str, Any]] = None + self._param_requires_grad_state: dict[str, bool] = {} + self._metric_attributes: Optional[dict[int, str]] = None + self._compiler_ctx: Optional[dict[str, Any]] = None # attributes only used when using fabric self._fabric: Optional[lf.Fabric] = None - self._fabric_optimizers: List[_FabricOptimizer] = [] + self._fabric_optimizers: list[_FabricOptimizer] = [] # access to device mesh in `conigure_model()` hook self._device_mesh: Optional[DeviceMesh] = None @@ -152,10 +147,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: @overload def optimizers( self, use_pl_optimizer: Literal[True] = True - ) -> Union[LightningOptimizer, List[LightningOptimizer]]: ... + ) -> Union[LightningOptimizer, list[LightningOptimizer]]: ... @overload - def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[Optimizer]]: ... + def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, list[Optimizer]]: ... @overload def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS: ... @@ -190,7 +185,7 @@ def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS: # multiple opts return opts - def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLType]: + def lr_schedulers(self) -> Union[None, list[LRSchedulerPLType], LRSchedulerPLType]: """Returns the learning rate scheduler(s) that are being used during training. Useful for manual optimization. Returns: @@ -202,7 +197,7 @@ def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLTyp return None # ignore other keys "interval", "frequency", etc. - lr_schedulers: List[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs] + lr_schedulers: list[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs] # single scheduler if len(lr_schedulers) == 1: @@ -240,7 +235,7 @@ def fabric(self, fabric: Optional["lf.Fabric"]) -> None: self._fabric = fabric @property - def example_input_array(self) -> Optional[Union[Tensor, Tuple, Dict]]: + def example_input_array(self) -> Optional[Union[Tensor, tuple, dict]]: """The example input array is a specification of what the module can consume in the :meth:`forward` method. The return type is interpreted as follows: @@ -255,7 +250,7 @@ def example_input_array(self) -> Optional[Union[Tensor, Tuple, Dict]]: return self._example_input_array @example_input_array.setter - def example_input_array(self, example: Optional[Union[Tensor, Tuple, Dict]]) -> None: + def example_input_array(self, example: Optional[Union[Tensor, tuple, dict]]) -> None: self._example_input_array = example @property @@ -318,7 +313,7 @@ def logger(self) -> Optional[Union[Logger, FabricLogger]]: return self._trainer.logger if self._trainer is not None else None @property - def loggers(self) -> Union[List[Logger], List[FabricLogger]]: + def loggers(self) -> Union[list[Logger], list[FabricLogger]]: """Reference to the list of loggers in the Trainer.""" if self._fabric is not None: return self._fabric.loggers @@ -531,7 +526,7 @@ def log( logger=logger, on_step=on_step, on_epoch=on_epoch, - reduce_fx=reduce_fx, # type: ignore[arg-type] + reduce_fx=reduce_fx, enable_graph=enable_graph, add_dataloader_idx=add_dataloader_idx, batch_size=batch_size, @@ -599,7 +594,7 @@ def log_dict( if self._fabric is not None: return self._log_dict_through_fabric(dictionary=dictionary, logger=logger) - kwargs: Dict[str, bool] = {} + kwargs: dict[str, bool] = {} if isinstance(dictionary, MetricCollection): kwargs["keep_base"] = False @@ -665,8 +660,8 @@ def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor return value def all_gather( - self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False - ) -> Union[Tensor, Dict, List, Tuple]: + self, data: Union[Tensor, dict, list, tuple], group: Optional[Any] = None, sync_grads: bool = False + ) -> Union[Tensor, dict, list, tuple]: r"""Gather tensors or collections of tensors from multiple processes. This method needs to be called on all processes and the tensors need to have the same shape across all @@ -1405,7 +1400,9 @@ def forward(self, x): input_sample = self._apply_batch_transfer_handler(input_sample) file_path = str(file_path) if isinstance(file_path, Path) else file_path - torch.onnx.export(self, input_sample, file_path, **kwargs) + # PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but + # BytesIO does work, too. + torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore self.train(mode) @torch.no_grad() @@ -1415,7 +1412,7 @@ def to_torchscript( method: Optional[str] = "script", example_inputs: Optional[Any] = None, **kwargs: Any, - ) -> Union[ScriptModule, Dict[str, ScriptModule]]: + ) -> Union[ScriptModule, dict[str, ScriptModule]]: """By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you want to use tracing, please provided the argument ``method='trace'`` and make sure that either the `example_inputs` argument is provided, or the model has :attr:`example_input_array` set. If you would like to customize the modules that are @@ -1592,7 +1589,7 @@ def load_from_checkpoint( return cast(Self, loaded) @override - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = dict(self.__dict__) state["_trainer"] = None return state diff --git a/src/lightning/pytorch/core/optimizer.py b/src/lightning/pytorch/core/optimizer.py index 777dca0b51dfe..46126e212378e 100644 --- a/src/lightning/pytorch/core/optimizer.py +++ b/src/lightning/pytorch/core/optimizer.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Generator from contextlib import contextmanager from dataclasses import fields -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, overload +from typing import Any, Callable, Optional, Union, overload from weakref import proxy import torch @@ -172,7 +173,7 @@ def __getattr__(self, item: Any) -> Any: def _init_optimizers_and_lr_schedulers( model: "pl.LightningModule", -) -> Tuple[List[Optimizer], List[LRSchedulerConfig]]: +) -> tuple[list[Optimizer], list[LRSchedulerConfig]]: """Calls `LightningModule.configure_optimizers` and parses and validates the output.""" from lightning.pytorch.trainer import call @@ -197,8 +198,8 @@ def _init_optimizers_and_lr_schedulers( def _configure_optimizers( - optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple], -) -> Tuple[List, List, Optional[str]]: + optim_conf: Union[dict[str, Any], list, Optimizer, tuple], +) -> tuple[list, list, Optional[str]]: optimizers, lr_schedulers = [], [] monitor = None @@ -246,7 +247,7 @@ def _configure_optimizers( return optimizers, lr_schedulers, monitor -def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: +def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> list[LRSchedulerConfig]: """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic optimization.""" lr_scheduler_configs = [] for scheduler in schedulers: @@ -301,7 +302,7 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] return lr_scheduler_configs -def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]: +def _configure_schedulers_manual_opt(schedulers: list) -> list[LRSchedulerConfig]: """Convert each scheduler into `LRSchedulerConfig` structure with relevant information, when using manual optimization.""" lr_scheduler_configs = [] @@ -326,7 +327,7 @@ def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig return lr_scheduler_configs -def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None: +def _validate_scheduler_api(lr_scheduler_configs: list[LRSchedulerConfig], model: "pl.LightningModule") -> None: for config in lr_scheduler_configs: scheduler = config.scheduler if not isinstance(scheduler, _Stateful): @@ -347,7 +348,7 @@ def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model ) -def _validate_multiple_optimizers_support(optimizers: List[Optimizer], model: "pl.LightningModule") -> None: +def _validate_multiple_optimizers_support(optimizers: list[Optimizer], model: "pl.LightningModule") -> None: if is_param_in_hook_signature(model.training_step, "optimizer_idx", explicit=True): raise RuntimeError( "Training with multiple optimizers is only supported with manual optimization. Remove the `optimizer_idx`" @@ -362,7 +363,7 @@ def _validate_multiple_optimizers_support(optimizers: List[Optimizer], model: "p ) -def _validate_optimizers_attached(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None: +def _validate_optimizers_attached(optimizers: list[Optimizer], lr_scheduler_configs: list[LRSchedulerConfig]) -> None: for config in lr_scheduler_configs: if config.scheduler.optimizer not in optimizers: raise MisconfigurationException( @@ -370,7 +371,7 @@ def _validate_optimizers_attached(optimizers: List[Optimizer], lr_scheduler_conf ) -def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None: +def _validate_optim_conf(optim_conf: dict[str, Any]) -> None: valid_keys = {"optimizer", "lr_scheduler", "monitor"} extra_keys = optim_conf.keys() - valid_keys if extra_keys: @@ -387,15 +388,15 @@ def __init__(self) -> None: super().__init__([torch.zeros(1)], {}) @override - def add_param_group(self, param_group: Dict[Any, Any]) -> None: + def add_param_group(self, param_group: dict[Any, Any]) -> None: pass # Do Nothing @override - def load_state_dict(self, state_dict: Dict[Any, Any]) -> None: + def load_state_dict(self, state_dict: dict[Any, Any]) -> None: pass # Do Nothing @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} # Return Empty @overload diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 521192f500b53..09d888c56bdcd 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -22,7 +22,7 @@ from copy import deepcopy from enum import Enum from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union +from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Union from warnings import warn import torch @@ -51,7 +51,7 @@ def _load_from_checkpoint( - cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], + cls: Union[type["pl.LightningModule"], type["pl.LightningDataModule"]], checkpoint_path: Union[_PATH, IO], map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, @@ -115,8 +115,8 @@ def _default_map_location(storage: "UntypedStorage", location: str) -> Optional[ def _load_state( - cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], - checkpoint: Dict[str, Any], + cls: Union[type["pl.LightningModule"], type["pl.LightningDataModule"]], + checkpoint: dict[str, Any], strict: Optional[bool] = None, **cls_kwargs_new: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: @@ -200,8 +200,8 @@ def _load_state( def _convert_loaded_hparams( - model_args: Dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None -) -> Dict[str, Any]: + model_args: dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None +) -> dict[str, Any]: """Convert hparams according given type in callable or string (past) format.""" # if not hparams type define if not hparams_type: @@ -243,7 +243,7 @@ def update_hparams(hparams: dict, updates: dict) -> None: hparams.update({k: v}) -def load_hparams_from_tags_csv(tags_csv: _PATH) -> Dict[str, Any]: +def load_hparams_from_tags_csv(tags_csv: _PATH) -> dict[str, Any]: """Load hparams from a file. >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') @@ -281,7 +281,7 @@ def save_hparams_to_tags_csv(tags_csv: _PATH, hparams: Union[dict, Namespace]) - writer.writerow({"key": k, "value": v}) -def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Dict[str, Any]: +def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> dict[str, Any]: """Load hparams from a file. Args: diff --git a/src/lightning/pytorch/demos/__init__.py b/src/lightning/pytorch/demos/__init__.py index fa91d7cac9fde..1e03e2fdfde95 100644 --- a/src/lightning/pytorch/demos/__init__.py +++ b/src/lightning/pytorch/demos/__init__.py @@ -1,2 +1,15 @@ -from lightning.pytorch.demos.lstm import LightningLSTM, SequenceSampler, SimpleLSTM # noqa: F401 -from lightning.pytorch.demos.transformer import LightningTransformer, Transformer, WikiText2 # noqa: F401 +from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, DemoModel +from lightning.pytorch.demos.lstm import LightningLSTM, SequenceSampler, SimpleLSTM +from lightning.pytorch.demos.transformer import LightningTransformer, Transformer, WikiText2 + +__all__ = [ + "LightningLSTM", + "SequenceSampler", + "SimpleLSTM", + "LightningTransformer", + "Transformer", + "WikiText2", + "BoringModel", + "BoringDataModule", + "DemoModel", +] diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py index fd2660228146e..3855f31898b81 100644 --- a/src/lightning/pytorch/demos/boring_classes.py +++ b/src/lightning/pytorch/demos/boring_classes.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Iterator, List, Optional, Tuple +from collections.abc import Iterable, Iterator +from typing import Any, Optional import torch import torch.nn as nn import torch.nn.functional as F +from lightning_utilities import apply_to_collection from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler @@ -35,7 +37,7 @@ def __init__(self, size: int, length: int): self.len = length self.data = torch.randn(length, size) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: a = self.data[index] b = a + 2 return {"a": a, "b": b} @@ -134,7 +136,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT: def test_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT: return {"y": self.step(batch)} - def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[LRScheduler]]: + def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[LRScheduler]]: optimizer = torch.optim.SGD(self.parameters(), lr=0.1) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) return [optimizer], [lr_scheduler] @@ -187,6 +189,86 @@ def predict_dataloader(self) -> DataLoader: return DataLoader(self.random_predict) +class BoringDataModuleNoLen(LightningDataModule): + """ + .. warning:: This is meant for testing/debugging and is experimental. + """ + + def __init__(self) -> None: + super().__init__() + + def setup(self, stage: str) -> None: + if stage == "fit": + self.random_train = RandomIterableDataset(32, 512) + + if stage in ("fit", "validate"): + self.random_val = RandomIterableDataset(32, 128) + + if stage == "test": + self.random_test = RandomIterableDataset(32, 256) + + if stage == "predict": + self.random_predict = RandomIterableDataset(32, 64) + + def train_dataloader(self) -> DataLoader: + return DataLoader(self.random_train) + + def val_dataloader(self) -> DataLoader: + return DataLoader(self.random_val) + + def test_dataloader(self) -> DataLoader: + return DataLoader(self.random_test) + + def predict_dataloader(self) -> DataLoader: + return DataLoader(self.random_predict) + + +class IterableBoringDataModule(LightningDataModule): + def __init__(self) -> None: + super().__init__() + + def setup(self, stage: str) -> None: + if stage == "fit": + self.train_datasets = [ + RandomDataset(4, 16), + RandomIterableDataset(4, 16), + ] + + if stage in ("fit", "validate"): + self.val_datasets = [ + RandomDataset(4, 32), + RandomIterableDataset(4, 32), + ] + + if stage == "test": + self.test_datasets = [ + RandomDataset(4, 64), + RandomIterableDataset(4, 64), + ] + + if stage == "predict": + self.predict_datasets = [ + RandomDataset(4, 128), + RandomIterableDataset(4, 128), + ] + + def train_dataloader(self) -> Iterable[DataLoader]: + combined_train = apply_to_collection(self.train_datasets, Dataset, lambda x: DataLoader(x)) + return combined_train + + def val_dataloader(self) -> DataLoader: + combined_val = apply_to_collection(self.val_datasets, Dataset, lambda x: DataLoader(x)) + return combined_val + + def test_dataloader(self) -> DataLoader: + combined_test = apply_to_collection(self.test_datasets, Dataset, lambda x: DataLoader(x)) + return combined_test + + def predict_dataloader(self) -> DataLoader: + combined_predict = apply_to_collection(self.predict_datasets, Dataset, lambda x: DataLoader(x)) + return combined_predict + + class ManualOptimBoringModel(BoringModel): """ .. warning:: This is meant for testing/debugging and is experimental. diff --git a/src/lightning/pytorch/demos/lstm.py b/src/lightning/pytorch/demos/lstm.py index 672b61ad0eff9..9432dd9acd1f8 100644 --- a/src/lightning/pytorch/demos/lstm.py +++ b/src/lightning/pytorch/demos/lstm.py @@ -5,7 +5,8 @@ """ -from typing import Iterator, List, Optional, Sized, Tuple +from collections.abc import Iterator, Sized +from typing import Optional import torch import torch.nn as nn @@ -37,14 +38,14 @@ def init_weights(self) -> None: nn.init.zeros_(self.decoder.bias) nn.init.uniform_(self.decoder.weight, -0.1, 0.1) - def forward(self, input: Tensor, hidden: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: + def forward(self, input: Tensor, hidden: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: emb = self.drop(self.encoder(input)) output, hidden = self.rnn(emb, hidden) output = self.drop(output) decoded = self.decoder(output).view(-1, self.vocab_size) return F.log_softmax(decoded, dim=1), hidden - def init_hidden(self, batch_size: int) -> Tuple[Tensor, Tensor]: + def init_hidden(self, batch_size: int) -> tuple[Tensor, Tensor]: weight = next(self.parameters()) return ( weight.new_zeros(self.nlayers, batch_size, self.nhid), @@ -52,14 +53,14 @@ def init_hidden(self, batch_size: int) -> Tuple[Tensor, Tensor]: ) -class SequenceSampler(Sampler[List[int]]): +class SequenceSampler(Sampler[list[int]]): def __init__(self, dataset: Sized, batch_size: int) -> None: super().__init__() self.dataset = dataset self.batch_size = batch_size self.chunk_size = len(self.dataset) // self.batch_size - def __iter__(self) -> Iterator[List[int]]: + def __iter__(self) -> Iterator[list[int]]: n = len(self.dataset) for i in range(self.chunk_size): yield list(range(i, n - (n % self.batch_size), self.chunk_size)) @@ -72,12 +73,12 @@ class LightningLSTM(LightningModule): def __init__(self, vocab_size: int = 33278): super().__init__() self.model = SimpleLSTM(vocab_size=vocab_size) - self.hidden: Optional[Tuple[Tensor, Tensor]] = None + self.hidden: Optional[tuple[Tensor, Tensor]] = None def on_train_epoch_end(self) -> None: self.hidden = None - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor: input, target = batch if self.hidden is None: self.hidden = self.model.init_hidden(input.size(0)) diff --git a/src/lightning/pytorch/demos/mnist_datamodule.py b/src/lightning/pytorch/demos/mnist_datamodule.py index 992527ab67296..73f46d4dc0986 100644 --- a/src/lightning/pytorch/demos/mnist_datamodule.py +++ b/src/lightning/pytorch/demos/mnist_datamodule.py @@ -16,7 +16,8 @@ import random import time import urllib -from typing import Any, Callable, Optional, Sized, Tuple, Union +from collections.abc import Sized +from typing import Any, Callable, Optional, Union from urllib.error import HTTPError from warnings import warn @@ -63,7 +64,7 @@ def __init__( data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file)) - def __getitem__(self, idx: int) -> Tuple[Tensor, int]: + def __getitem__(self, idx: int) -> tuple[Tensor, int]: img = self.data[idx].float().unsqueeze(0) target = int(self.targets[idx]) @@ -99,7 +100,7 @@ def _download(self, data_folder: str) -> None: urllib.request.urlretrieve(url, fpath) # noqa: S310 @staticmethod - def _try_load(path_data: str, trials: int = 30, delta: float = 1.0) -> Tuple[Tensor, Tensor]: + def _try_load(path_data: str, trials: int = 30, delta: float = 1.0) -> tuple[Tensor, Tensor]: """Resolving loading from the same time from multiple concurrent processes.""" res, exception = None, None assert trials, "at least some trial has to be set" diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py index ac83b5539f249..eca86b4cb4dc7 100644 --- a/src/lightning/pytorch/demos/transformer.py +++ b/src/lightning/pytorch/demos/transformer.py @@ -8,7 +8,7 @@ import math import os from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Optional import torch import torch.nn as nn @@ -88,7 +88,7 @@ def forward(self, x: Tensor) -> Tensor: # TODO: Could make this a `nn.Parameter` with `requires_grad=False` self.pe = self._init_pos_encoding(device=x.device) - x = x + self.pe[: x.size(0), :] + x = x + self.pe[:, x.size(1)] return self.dropout(x) def _init_pos_encoding(self, device: torch.device) -> Tensor: @@ -97,7 +97,7 @@ def _init_pos_encoding(self, device: torch.device) -> Tensor: div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0).transpose(0, 1) + pe = pe.unsqueeze(0) return pe @@ -119,7 +119,7 @@ def vocab_size(self) -> int: def __len__(self) -> int: return len(self.data) // self.block_size - 1 - def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: + def __getitem__(self, index: int) -> tuple[Tensor, Tensor]: start = index * self.block_size end = start + self.block_size inputs = self.data[start:end] @@ -143,8 +143,8 @@ def download(destination: Path) -> None: class Dictionary: def __init__(self) -> None: - self.word2idx: Dict[str, int] = {} - self.idx2word: List[str] = [] + self.word2idx: dict[str, int] = {} + self.idx2word: list[str] = [] def add_word(self, word: str) -> int: if word not in self.word2idx: @@ -156,7 +156,7 @@ def __len__(self) -> int: return len(self.idx2word) -def tokenize(path: Path) -> Tuple[Tensor, Dictionary]: +def tokenize(path: Path) -> tuple[Tensor, Dictionary]: dictionary = Dictionary() assert os.path.exists(path) @@ -169,10 +169,10 @@ def tokenize(path: Path) -> Tuple[Tensor, Dictionary]: # Tokenize file content with open(path, encoding="utf8") as f: - idss: List[Tensor] = [] + idss: list[Tensor] = [] for line in f: words = line.split() + [""] - ids: List[int] = [] + ids: list[int] = [] for word in words: ids.append(dictionary.word2idx[word]) idss.append(torch.tensor(ids).type(torch.int64)) @@ -188,7 +188,7 @@ def __init__(self, vocab_size: int = 33278) -> None: def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return self.model(inputs, target) - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor: inputs, target = batch output = self(inputs, target) loss = torch.nn.functional.nll_loss(output, target.view(-1)) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 277af5c85f539..7ff7605249f2f 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -19,23 +19,27 @@ import logging import os from argparse import Namespace -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Literal, Optional, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor from torch.nn import Module from typing_extensions import override -from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict +from lightning.fabric.utilities.logger import _convert_params +from lightning.fabric.utilities.rank_zero import _get_rank from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment -from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_only if TYPE_CHECKING: from comet_ml import ExistingExperiment, Experiment, OfflineExperiment log = logging.getLogger(__name__) -_COMET_AVAILABLE = RequirementCache("comet-ml>=3.31.0", module="comet_ml") +_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4", module="comet_ml") + +FRAMEWORK_NAME = "pytorch-lightning" +comet_experiment = Union["Experiment", "ExistingExperiment", "OfflineExperiment"] class CometLogger(Logger): @@ -60,13 +64,11 @@ class CometLogger(Logger): # arguments made to CometLogger are passed on to the comet_ml.Experiment class comet_logger = CometLogger( - api_key=os.environ.get("COMET_API_KEY"), + api_key=os.environ.get("COMET_API_KEY"), # Optional workspace=os.environ.get("COMET_WORKSPACE"), # Optional - save_dir=".", # Optional - project_name="default_project", # Optional - rest_api_key=os.environ.get("COMET_REST_API_KEY"), # Optional + project="default_project", # Optional experiment_key=os.environ.get("COMET_EXPERIMENT_KEY"), # Optional - experiment_name="lightning_logs", # Optional + name="lightning_logs", # Optional ) trainer = Trainer(logger=comet_logger) @@ -78,11 +80,10 @@ class CometLogger(Logger): # arguments made to CometLogger are passed on to the comet_ml.Experiment class comet_logger = CometLogger( - save_dir=".", workspace=os.environ.get("COMET_WORKSPACE"), # Optional - project_name="default_project", # Optional - rest_api_key=os.environ.get("COMET_REST_API_KEY"), # Optional - experiment_name="lightning_logs", # Optional + project="default_project", # Optional + name="lightning_logs", # Optional + online=False ) trainer = Trainer(logger=comet_logger) @@ -106,6 +107,9 @@ def __init__(self, *args, **kwarg): # log multiple parameters logger.log_hyperparams({"batch_size": 16, "learning_rate": 0.001}) + # log nested parameters + logger.log_hyperparams({"specific": {'param': {'subparam': "value"}}}) + **Log Metrics:** .. code-block:: python @@ -116,6 +120,9 @@ def __init__(self, *args, **kwarg): # add multiple metrics logger.log_metrics({"train/loss": 0.001, "val/loss": 0.002}) + # add nested metrics + logger.log_metrics({"specific": {'metric': {'submetric': "value"}}}) + **Access the Comet Experiment object:** You can gain access to the underlying Comet @@ -166,100 +173,134 @@ def __init__(self, *args, **kwarg): - `Comet Documentation `__ Args: - api_key: Required in online mode. API key, found on Comet.ml. If not given, this - will be loaded from the environment variable COMET_API_KEY or ~/.comet.config - if either exists. - save_dir: Required in offline mode. The path for the directory to save local - comet logs. If given, this also sets the directory for saving checkpoints. - project_name: Optional. Send your experiment to a specific project. - Otherwise will be sent to Uncategorized Experiments. - If the project name does not already exist, Comet.ml will create a new project. - rest_api_key: Optional. Rest API key found in Comet.ml settings. - This is used to determine version number - experiment_name: Optional. String representing the name for this particular experiment on Comet.ml. - experiment_key: Optional. If set, restores from existing experiment. - offline: If api_key and save_dir are both given, this determines whether - the experiment will be in online or offline mode. This is useful if you use - save_dir to control the checkpoints directory and have a ~/.comet.config - file but still want to run offline experiments. - prefix: A string to put at the beginning of metric keys. - \**kwargs: Additional arguments like `workspace`, `log_code`, etc. used by + api_key: Comet API key. It's recommended to configure the API Key with `comet login`. + workspace: Comet workspace name. If not provided, uses the default workspace. + project: Comet project name. Defaults to `Uncategorized`. + experiment_key: The Experiment identifier to be used for logging. This is used either to append + data to an Existing Experiment or to control the key of new experiments (for example to match another + identifier). Must be an alphanumeric string whose length is between 32 and 50 characters. + mode: Control how the Comet experiment is started. + * ``"get_or_create"``: Starts a fresh experiment if required, or persists logging to an existing one. + * ``"get"``: Continue logging to an existing experiment identified by the ``experiment_key`` value. + * ``"create"``: Always creates of a new experiment, useful for HPO sweeps. + online: If True, the data will be logged to Comet server, otherwise it will be stored + locally in an offline experiment. Default is ``True``. + prefix: The prefix to add to names of the logged metrics. + example: prefix=`exp1`, then metric name will be logged as `exp1_metric_name` + **kwargs: Additional arguments like `name`, `log_code`, `offline_directory` etc. used by :class:`CometExperiment` can be passed as keyword arguments in this logger. Raises: ModuleNotFoundError: If required Comet package is not installed on the device. - MisconfigurationException: - If neither ``api_key`` nor ``save_dir`` are passed as arguments. """ - LOGGER_JOIN_CHAR = "-" - def __init__( self, + *, api_key: Optional[str] = None, - save_dir: Optional[str] = None, - project_name: Optional[str] = None, - rest_api_key: Optional[str] = None, - experiment_name: Optional[str] = None, + workspace: Optional[str] = None, + project: Optional[str] = None, experiment_key: Optional[str] = None, - offline: bool = False, - prefix: str = "", + mode: Optional[Literal["get_or_create", "get", "create"]] = None, + online: Optional[bool] = None, + prefix: Optional[str] = None, **kwargs: Any, ): if not _COMET_AVAILABLE: raise ModuleNotFoundError(str(_COMET_AVAILABLE)) + super().__init__() - self._experiment = None - self._save_dir: Optional[str] - self.rest_api_key: Optional[str] + + ################################################## + # HANDLE PASSED OLD TYPE PARAMS + + # handle old "experiment_name" param + if "experiment_name" in kwargs: + log.warning("The parameter `experiment_name` is deprecated, please use `name` instead.") + experiment_name = kwargs.pop("experiment_name") + + if "name" not in kwargs: + kwargs["name"] = experiment_name + else: + log.warning("You specified both `experiment_name` and `name` parameters, please use `name` only") + + # handle old "project_name" param + if "project_name" in kwargs: + log.warning("The parameter `project_name` is deprecated, please use `project` instead.") + if project is None: + project = kwargs.pop("project_name") + else: + log.warning("You specified both `project_name` and `project` parameters, please use `project` only") + + # handle old "offline" experiment flag + if "offline" in kwargs: + log.warning("The parameter `offline is deprecated, please use `online` instead.") + if online is None: + online = kwargs.pop("offline") + else: + log.warning("You specified both `offline` and `online` parameters, please use `online` only") + + # handle old "save_dir" param + if "save_dir" in kwargs: + log.warning("The parameter `save_dir` is deprecated, please use `offline_directory` instead.") + if "offline_directory" not in kwargs: + kwargs["offline_directory"] = kwargs.pop("save_dir") + else: + log.warning( + "You specified both `save_dir` and `offline_directory` parameters, " + "please use `offline_directory` only" + ) + ################################################## + + self._api_key: Optional[str] = api_key + self._experiment: Optional[comet_experiment] = None + self._workspace: Optional[str] = workspace + self._mode: Optional[Literal["get_or_create", "get", "create"]] = mode + self._online: Optional[bool] = online + self._project_name: Optional[str] = project + self._experiment_key: Optional[str] = experiment_key + self._prefix: Optional[str] = prefix + self._kwargs: dict[str, Any] = kwargs # needs to be set before the first `comet_ml` import + # because comet_ml imported after another machine learning libraries (Torch) os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1" import comet_ml - # Determine online or offline mode based on which arguments were passed to CometLogger - api_key = api_key or comet_ml.config.get_api_key(None, comet_ml.config.get_config()) - - if api_key is not None and save_dir is not None: - self.mode = "offline" if offline else "online" - self.api_key = api_key - self._save_dir = save_dir - elif api_key is not None: - self.mode = "online" - self.api_key = api_key - self._save_dir = None - elif save_dir is not None: - self.mode = "offline" - self._save_dir = save_dir - else: - # If neither api_key nor save_dir are passed as arguments, raise an exception - raise MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.") - - log.info(f"CometLogger will be initialized in {self.mode} mode") - - self._project_name: Optional[str] = project_name - self._experiment_key: Optional[str] = experiment_key - self._experiment_name: Optional[str] = experiment_name - self._prefix: str = prefix - self._kwargs: Any = kwargs - self._future_experiment_key: Optional[str] = None + self._comet_config = comet_ml.ExperimentConfig(**self._kwargs) - if rest_api_key is not None: - from comet_ml.api import API + # create real experiment only on main node/process (when strategy=auto/ddp) + if _get_rank() is not None and _get_rank() != 0: + return + + self._create_experiment() + + def _create_experiment(self) -> None: + import comet_ml - # Comet.ml rest API, used to determine version number - self.rest_api_key = rest_api_key - self.comet_api = API(self.rest_api_key) - else: - self.rest_api_key = None - self.comet_api = None + self._experiment = comet_ml.start( + api_key=self._api_key, + workspace=self._workspace, + project=self._project_name, + experiment_key=self._experiment_key, + mode=self._mode, + online=self._online, + experiment_config=self._comet_config, + ) + + if self._experiment is None: + raise comet_ml.exceptions.ExperimentNotFound("Failed to create Comet experiment.") + + self._experiment_key = self._experiment.get_key() + self._project_name = self._experiment.project_name + self._experiment.log_other("Created from", FRAMEWORK_NAME) @property @rank_zero_experiment - def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperiment"]: + def experiment(self) -> comet_experiment: r"""Actual Comet object. To use Comet features in your :class:`~lightning.pytorch.core.LightningModule` do the following. @@ -268,82 +309,56 @@ def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperi self.logger.experiment.some_comet_function() """ - if self._experiment is not None and self._experiment.alive: - return self._experiment - - if self._future_experiment_key is not None: - os.environ["COMET_EXPERIMENT_KEY"] = self._future_experiment_key - - from comet_ml import ExistingExperiment, Experiment, OfflineExperiment - - try: - if self.mode == "online": - if self._experiment_key is None: - self._experiment = Experiment(api_key=self.api_key, project_name=self._project_name, **self._kwargs) - self._experiment_key = self._experiment.get_key() - else: - self._experiment = ExistingExperiment( - api_key=self.api_key, - project_name=self._project_name, - previous_experiment=self._experiment_key, - **self._kwargs, - ) - else: - self._experiment = OfflineExperiment( - offline_directory=self.save_dir, project_name=self._project_name, **self._kwargs - ) - self._experiment.log_other("Created from", "pytorch-lightning") - finally: - if self._future_experiment_key is not None: - os.environ.pop("COMET_EXPERIMENT_KEY") - self._future_experiment_key = None - if self._experiment_name: - self._experiment.set_name(self._experiment_name) + # if by some chance there is no experiment created yet (for example, when strategy=ddp_spawn) + # then we will create a new one + if not self._experiment: + self._create_experiment() return self._experiment @override @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: params = _convert_params(params) - params = _flatten_dict(params) - self.experiment.log_parameters(params) + self.experiment.__internal_api__log_parameters__( + parameters=params, + framework=FRAMEWORK_NAME, + flatten_nested=True, + source="manual", + ) @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" - # Comet.ml expects metrics to be a dictionary of detached tensors on CPU + # Comet.com expects metrics to be a dictionary of detached tensors on CPU metrics_without_epoch = metrics.copy() for key, val in metrics_without_epoch.items(): if isinstance(val, Tensor): metrics_without_epoch[key] = val.cpu().detach() epoch = metrics_without_epoch.pop("epoch", None) - metrics_without_epoch = _add_prefix(metrics_without_epoch, self._prefix, self.LOGGER_JOIN_CHAR) - self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch) - - def reset_experiment(self) -> None: - self._experiment = None + self.experiment.__internal_api__log_metrics__( + metrics_without_epoch, + step=step, + epoch=epoch, + prefix=self._prefix, + framework=FRAMEWORK_NAME, + ) @override @rank_zero_only def finalize(self, status: str) -> None: - r"""When calling ``self.experiment.end()``, that experiment won't log any more data to Comet. That's why, if you - need to log any more data, you need to create an ExistingCometExperiment. For example, to log data when testing - your model after training, because when training is finalized :meth:`CometLogger.finalize` is called. - - This happens automatically in the :meth:`~CometLogger.experiment` property, when - ``self._experiment`` is set to ``None``, i.e. ``self.reset_experiment()``. - - """ + """We will not end experiment (will not call self._experiment.end()) here to have an ability to continue using + it after training is complete but instead of ending we will upload/save all the data.""" if self._experiment is None: # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been # initialized there return - self.experiment.end() - self.reset_experiment() + + # just save the data + self.experiment.flush() @property @override @@ -354,69 +369,39 @@ def save_dir(self) -> Optional[str]: The path to the save directory. """ - return self._save_dir + return self._comet_config.offline_directory @property @override - def name(self) -> str: + def name(self) -> Optional[str]: """Gets the project name. Returns: - The project name if it is specified, else "comet-default". + The project name if it is specified. """ - # Don't create an experiment if we don't have one - if self._experiment is not None and self._experiment.project_name is not None: - return self._experiment.project_name - - if self._project_name is not None: - return self._project_name - - return "comet-default" + return self._project_name @property @override - def version(self) -> str: + def version(self) -> Optional[str]: """Gets the version. Returns: - The first one of the following that is set in the following order - - 1. experiment id. - 2. experiment key. - 3. "COMET_EXPERIMENT_KEY" environment variable. - 4. future experiment key. - - If none are present generates a new guid. + The experiment key if present """ # Don't create an experiment if we don't have one if self._experiment is not None: - return self._experiment.id - - if self._experiment_key is not None: - return self._experiment_key - - if "COMET_EXPERIMENT_KEY" in os.environ: - return os.environ["COMET_EXPERIMENT_KEY"] - - if self._future_experiment_key is not None: - return self._future_experiment_key - - import comet_ml - - # Pre-generate an experiment key - self._future_experiment_key = comet_ml.generate_guid() - - return self._future_experiment_key + return self._experiment.get_key() - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() # Save the experiment id in case an experiment object already exists, # this way we could create an ExistingExperiment pointing to the same # experiment - state["_experiment_key"] = self._experiment.id if self._experiment is not None else None + state["_experiment_key"] = self._experiment.get_key() if self._experiment is not None else None # Remove the experiment object as it contains hard to pickle objects # (like network connections), the experiment object will be recreated if @@ -427,4 +412,7 @@ def __getstate__(self) -> Dict[str, Any]: @override def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: if self._experiment is not None: - self._experiment.set_model_graph(model) + self._experiment.__internal_api__set_model_graph__( + graph=model, + framework=FRAMEWORK_NAME, + ) diff --git a/src/lightning/pytorch/loggers/csv_logs.py b/src/lightning/pytorch/loggers/csv_logs.py index caca0c181c6ff..8606264dc3cdb 100644 --- a/src/lightning/pytorch/loggers/csv_logs.py +++ b/src/lightning/pytorch/loggers/csv_logs.py @@ -21,7 +21,7 @@ import os from argparse import Namespace -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from typing_extensions import override @@ -52,9 +52,9 @@ class ExperimentWriter(_FabricExperimentWriter): def __init__(self, log_dir: str) -> None: super().__init__(log_dir=log_dir) - self.hparams: Dict[str, Any] = {} + self.hparams: dict[str, Any] = {} - def log_hparams(self, params: Dict[str, Any]) -> None: + def log_hparams(self, params: dict[str, Any]) -> None: """Record hparams.""" self.hparams.update(params) @@ -144,7 +144,7 @@ def save_dir(self) -> str: @override @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: params = _convert_params(params) self.experiment.log_hparams(params) diff --git a/src/lightning/pytorch/loggers/logger.py b/src/lightning/pytorch/loggers/logger.py index 40e8ed8c4a13e..668fe39cb67d2 100644 --- a/src/lightning/pytorch/loggers/logger.py +++ b/src/lightning/pytorch/loggers/logger.py @@ -18,7 +18,8 @@ import statistics from abc import ABC from collections import defaultdict -from typing import Any, Callable, Dict, Mapping, Optional, Sequence +from collections.abc import Mapping, Sequence +from typing import Any, Callable, Optional from typing_extensions import override @@ -101,7 +102,7 @@ def merge_dicts( # pragma: no cover dicts: Sequence[Mapping], agg_key_funcs: Optional[Mapping] = None, default_func: Callable[[Sequence[float]], float] = statistics.mean, -) -> Dict: +) -> dict: """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given function. Args: @@ -137,7 +138,7 @@ def merge_dicts( # pragma: no cover """ agg_key_funcs = agg_key_funcs or {} keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts])) - d_out: Dict = defaultdict(dict) + d_out: dict = defaultdict(dict) for k in keys: fn = agg_key_funcs.get(k) values_to_agg = [v for v in [d_in.get(k) for d_in in dicts] if v is not None] diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index 4437d5f8a7c76..e3d99987b7f58 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -21,9 +21,10 @@ import re import tempfile from argparse import Namespace +from collections.abc import Mapping from pathlib import Path from time import time -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union import yaml from lightning_utilities.core.imports import RequirementCache @@ -117,7 +118,7 @@ def __init__( experiment_name: str = "lightning_logs", run_name: Optional[str] = None, tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"), - tags: Optional[Dict[str, Any]] = None, + tags: Optional[dict[str, Any]] = None, save_dir: Optional[str] = "./mlruns", log_model: Literal[True, False, "all"] = False, prefix: str = "", @@ -140,7 +141,7 @@ def __init__( self._run_id = run_id self.tags = tags self._log_model = log_model - self._logged_model_time: Dict[str, float] = {} + self._logged_model_time: dict[str, float] = {} self._checkpoint_callback: Optional[ModelCheckpoint] = None self._prefix = prefix self._artifact_location = artifact_location @@ -227,7 +228,7 @@ def experiment_id(self) -> Optional[str]: @override @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _flatten_dict(params) @@ -249,7 +250,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) from mlflow.entities import Metric metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) - metrics_list: List[Metric] = [] + metrics_list: list[Metric] = [] timestamp_ms = int(time() * 1000) for k, v in metrics.items(): @@ -299,7 +300,7 @@ def save_dir(self) -> Optional[str]: """ if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX): - return self._tracking_uri.lstrip(LOCAL_FILE_URI_PREFIX) + return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX) :] return None @property @@ -360,7 +361,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] # Artifact path on mlflow - artifact_path = f"model/checkpoints/{Path(p).stem}" + artifact_path = Path(p).stem # Log the checkpoint self.experiment.log_artifact(self._run_id, p, artifact_path) diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index 691dbe0ba2ff7..a363f589b29b4 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -20,8 +20,9 @@ import logging import os from argparse import Namespace +from collections.abc import Generator from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -286,8 +287,8 @@ def _retrieve_run_data(self) -> None: self._run_name = "offline-name" @property - def _neptune_init_args(self) -> Dict: - args: Dict = {} + def _neptune_init_args(self) -> dict: + args: dict = {} # Backward compatibility in case of previous version retrieval with contextlib.suppress(AttributeError): args = self._neptune_run_kwargs @@ -337,13 +338,13 @@ def _verify_input_arguments( " parameters." ) - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() # Run instance can't be pickled state["_run_instance"] = None return state - def __setstate__(self, state: Dict[str, Any]) -> None: + def __setstate__(self, state: dict[str, Any]) -> None: import neptune self.__dict__ = state @@ -395,7 +396,7 @@ def run(self) -> "Run": @override @rank_zero_only @_catch_inactive - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: r"""Log hyperparameters to the run. Hyperparameters will be logged under the "/hyperparams" namespace. @@ -443,7 +444,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: @override @rank_zero_only @_catch_inactive - def log_metrics(self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None: """Log metrics (numeric values) in Neptune runs. Args: @@ -563,16 +564,16 @@ def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> st return model_path.replace(os.sep, "/") @classmethod - def _get_full_model_names_from_exp_structure(cls, exp_structure: Dict[str, Any], namespace: str) -> Set[str]: + def _get_full_model_names_from_exp_structure(cls, exp_structure: dict[str, Any], namespace: str) -> set[str]: """Returns all paths to properties which were already logged in `namespace`""" - structure_keys: List[str] = namespace.split(cls.LOGGER_JOIN_CHAR) + structure_keys: list[str] = namespace.split(cls.LOGGER_JOIN_CHAR) for key in structure_keys: exp_structure = exp_structure[key] uploaded_models_dict = exp_structure return set(cls._dict_paths(uploaded_models_dict)) @classmethod - def _dict_paths(cls, d: Dict[str, Any], path_in_build: Optional[str] = None) -> Generator: + def _dict_paths(cls, d: dict[str, Any], path_in_build: Optional[str] = None) -> Generator: for k, v in d.items(): path = f"{path_in_build}/{k}" if path_in_build is not None else k if not isinstance(v, dict): diff --git a/src/lightning/pytorch/loggers/tensorboard.py b/src/lightning/pytorch/loggers/tensorboard.py index 88e026f6945e0..f9cc41c67045c 100644 --- a/src/lightning/pytorch/loggers/tensorboard.py +++ b/src/lightning/pytorch/loggers/tensorboard.py @@ -18,7 +18,7 @@ import os from argparse import Namespace -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import override @@ -108,7 +108,7 @@ def __init__( f"{str(_TENSORBOARD_AVAILABLE)}" ) self._log_graph = log_graph and _TENSORBOARD_AVAILABLE - self.hparams: Union[Dict[str, Any], Namespace] = {} + self.hparams: Union[dict[str, Any], Namespace] = {} @property @override @@ -153,15 +153,19 @@ def save_dir(self) -> str: @override @rank_zero_only def log_hyperparams( - self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None + self, + params: Union[dict[str, Any], Namespace], + metrics: Optional[dict[str, Any]] = None, + step: Optional[int] = None, ) -> None: """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to display the new ones with hyperparameters. Args: - params: a dictionary-like container with the hyperparameters + params: A dictionary-like container with the hyperparameters metrics: Dictionary with metric names as keys and measured quantities as values + step: Optional global step number for the logged metrics """ if _OMEGACONF_AVAILABLE: @@ -175,7 +179,7 @@ def log_hyperparams( else: self.hparams.update(params) - return super().log_hyperparams(params=params, metrics=metrics) + return super().log_hyperparams(params=params, metrics=metrics, step=step) @override @rank_zero_only diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 2ff9cbd24eca5..e1752c67d9183 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -14,7 +14,7 @@ """Utilities for loggers.""" from pathlib import Path -from typing import Any, List, Tuple, Union +from typing import Any, Union from torch import Tensor @@ -22,14 +22,14 @@ from lightning.pytorch.callbacks import Checkpoint -def _version(loggers: List[Any], separator: str = "_") -> Union[int, str]: +def _version(loggers: list[Any], separator: str = "_") -> Union[int, str]: if len(loggers) == 1: return loggers[0].version # Concatenate versions together, removing duplicates and preserving order return separator.join(dict.fromkeys(str(logger.version) for logger in loggers)) -def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict) -> List[Tuple[float, str, float, str]]: +def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict) -> list[tuple[float, str, float, str]]: """Return the checkpoints to be logged. Args: @@ -69,6 +69,9 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None: lightning_hparams = pl_module.hparams_initial inconsistent_keys = [] for key in lightning_hparams.keys() & datamodule_hparams.keys(): + if key == "_class_path": + # Skip LightningCLI's internal hparam + continue lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] if ( type(lm_val) != type(dm_val) @@ -88,6 +91,10 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None: elif datamodule_log_hyperparams: hparams_initial = trainer.datamodule.hparams_initial + # Don't log LightningCLI's internal hparam + if hparams_initial is not None: + hparams_initial = {k: v for k, v in hparams_initial.items() if k != "_class_path"} + for logger in trainer.loggers: if hparams_initial is not None: logger.log_hyperparams(hparams_initial) diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index c5d995bff35a5..2429748f73179 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -18,8 +18,9 @@ import os from argparse import Namespace +from collections.abc import Mapping from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch.nn as nn from lightning_utilities.core.imports import RequirementCache @@ -48,7 +49,7 @@ class WandbLogger(Logger): - r"""Log using `Weights and Biases `_. + r"""Log using `Weights and Biases `_. **Installation and set-up** @@ -253,7 +254,7 @@ def any_lightning_module_function_or_hook(self): See Also: - `Demo in Google Colab `__ with hyperparameter search and model logging - - `W&B Documentation `__ + - `W&B Documentation `__ Args: name: Display name for the run. @@ -320,7 +321,7 @@ def __init__( self._log_model = log_model self._prefix = prefix self._experiment = experiment - self._logged_model_time: Dict[str, float] = {} + self._logged_model_time: dict[str, float] = {} self._checkpoint_callback: Optional[ModelCheckpoint] = None # paths are processed as strings @@ -332,7 +333,7 @@ def __init__( project = project or os.environ.get("WANDB_PROJECT", "lightning_logs") # set wandb init arguments - self._wandb_init: Dict[str, Any] = { + self._wandb_init: dict[str, Any] = { "name": name, "project": project, "dir": save_dir or dir, @@ -348,7 +349,7 @@ def __init__( self._id = self._wandb_init.get("id") self._checkpoint_name = checkpoint_name - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: import wandb # Hack: If the 'spawn' launch method is used, the logger will get pickled and this `__getstate__` gets called. @@ -421,7 +422,7 @@ def watch( @override @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _sanitize_callable_params(params) params = _convert_json_serializable(params) @@ -442,8 +443,8 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) def log_table( self, key: str, - columns: Optional[List[str]] = None, - data: Optional[List[List[Any]]] = None, + columns: Optional[list[str]] = None, + data: Optional[list[list[Any]]] = None, dataframe: Any = None, step: Optional[int] = None, ) -> None: @@ -461,8 +462,8 @@ def log_table( def log_text( self, key: str, - columns: Optional[List[str]] = None, - data: Optional[List[List[str]]] = None, + columns: Optional[list[str]] = None, + data: Optional[list[list[str]]] = None, dataframe: Any = None, step: Optional[int] = None, ) -> None: @@ -475,7 +476,7 @@ def log_text( self.log_table(key, columns, data, dataframe, step) @rank_zero_only - def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: Any) -> None: + def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: """Log images (tensors, numpy arrays, PIL Images or file paths). Optional kwargs are lists passed to each image (ex: caption, masks, boxes). @@ -495,7 +496,7 @@ def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **k self.log_metrics(metrics, step) # type: ignore[arg-type] @rank_zero_only - def log_audio(self, key: str, audios: List[Any], step: Optional[int] = None, **kwargs: Any) -> None: + def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: r"""Log audios (numpy arrays, or file paths). Args: @@ -521,7 +522,7 @@ def log_audio(self, key: str, audios: List[Any], step: Optional[int] = None, **k self.log_metrics(metrics, step) # type: ignore[arg-type] @rank_zero_only - def log_video(self, key: str, videos: List[Any], step: Optional[int] = None, **kwargs: Any) -> None: + def log_video(self, key: str, videos: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: """Log videos (numpy arrays, or file paths). Args: diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 0ab3901cf072d..d007466ee3b1c 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -15,7 +15,9 @@ import shutil import sys from collections import ChainMap, OrderedDict, defaultdict -from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union +from collections.abc import Iterable, Iterator +from dataclasses import dataclass +from typing import Any, Optional, Union from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor @@ -45,6 +47,12 @@ from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature +@dataclass +class RestartStage: + NONE = "none" + RESTARTED_MID_EVALUATION = "restarted_mid_evaluation" + + class _EvaluationLoop(_Loop): """Top-level loop where validation/testing starts.""" @@ -60,19 +68,20 @@ def __init__( self.verbose = verbose self.inference_mode = inference_mode self.batch_progress = _BatchProgress() # across dataloaders - self._max_batches: List[Union[int, float]] = [] + self._max_batches: list[Union[int, float]] = [] self._results = _ResultCollection(training=False) - self._logged_outputs: List[_OUT_DICT] = [] + self._logged_outputs: list[_OUT_DICT] = [] self._has_run: bool = False self._trainer_fn = trainer_fn self._stage = stage self._data_source = _DataLoaderSource(None, f"{stage.dataloader_prefix}_dataloader") self._combined_loader: Optional[CombinedLoader] = None self._data_fetcher: Optional[_DataFetcher] = None - self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int) + self._seen_batches_per_dataloader: defaultdict[int, int] = defaultdict(int) self._last_val_dl_reload_epoch = float("-inf") self._module_mode = _ModuleMode() + self._restart_stage = RestartStage.NONE @property def num_dataloaders(self) -> int: @@ -82,7 +91,7 @@ def num_dataloaders(self) -> int: return len(combined_loader.flattened) @property - def max_batches(self) -> List[Union[int, float]]: + def max_batches(self) -> list[Union[int, float]]: """The max number of batches to run per dataloader.""" max_batches = self._max_batches if not self.trainer.sanity_checking: @@ -106,7 +115,7 @@ def _is_sequential(self) -> bool: return self._combined_loader._mode == "sequential" @_no_grad_context - def run(self) -> List[_OUT_DICT]: + def run(self) -> list[_OUT_DICT]: self.setup_data() if self.skip: return [] @@ -137,7 +146,7 @@ def run(self) -> List[_OUT_DICT]: # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support break finally: - self._restarting = False + self.on_iteration_done() self._store_dataloader_outputs() return self.on_run_end() @@ -197,6 +206,24 @@ def setup_data(self) -> None: # this depends on the data used, so reset it too self._seen_batches_per_dataloader = defaultdict(int) + @property + def restarted_mid_evaluation(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_MID_EVALUATION + + def update_restart_stage(self) -> None: + if ( + self.restarting + and self.batch_progress.total.started == self.batch_progress.total.ready + and self.batch_progress.total.processed == self.batch_progress.total.started - 1 + and self.batch_progress.total.completed == self.batch_progress.total.processed + ): + self._restart_stage = RestartStage.RESTARTED_MID_EVALUATION + else: + self._restart_stage = RestartStage.NONE + + def reset_restart_stage(self) -> None: + self._restart_stage = RestartStage.NONE + def reset(self) -> None: """Resets the internal state of the loop.""" trainer = self.trainer @@ -236,6 +263,16 @@ def reset(self) -> None: data_fetcher._stop_profiler = self._on_after_fetch self._data_fetcher = data_fetcher + def increment_progress_to_evaluation_end(self) -> None: + self.setup_data() + if self.skip: + return + self.reset() + max_batch = int(max(self.max_batches)) + if max_batch == -1: + return + self.batch_progress.increment_by(max_batch, True) + def on_run_start(self) -> None: """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start`` hooks.""" @@ -244,7 +281,7 @@ def on_run_start(self) -> None: self._on_evaluation_start() self._on_evaluation_epoch_start() - def on_run_end(self) -> List[_OUT_DICT]: + def on_run_end(self) -> list[_OUT_DICT]: """Runs the ``_on_evaluation_epoch_end`` hook.""" # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end` self.trainer._logger_connector.epoch_end_reached() @@ -472,7 +509,7 @@ def _verify_dataloader_idx_requirement(self) -> None: ) @staticmethod - def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]: + def _get_keys(data: dict) -> Iterable[tuple[str, ...]]: for k, v in data.items(): if isinstance(v, dict): for new_key in apply_to_collection(v, dict, _EvaluationLoop._get_keys): @@ -491,7 +528,7 @@ def _find_value(data: dict, target: Iterable[str]) -> Optional[Any]: return _EvaluationLoop._find_value(result, rest) @staticmethod - def _print_results(results: List[_OUT_DICT], stage: str) -> None: + def _print_results(results: list[_OUT_DICT], stage: str) -> None: # remove the dl idx suffix results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results] metrics_paths = {k for keys in apply_to_collection(results, dict, _EvaluationLoop._get_keys) for k in keys} @@ -508,7 +545,7 @@ def _print_results(results: List[_OUT_DICT], stage: str) -> None: term_size = shutil.get_terminal_size(fallback=(120, 30)).columns or 120 max_length = int(min(max(len(max(metrics_strs, key=len)), len(max(headers, key=len)), 25), term_size / 2)) - rows: List[List[Any]] = [[] for _ in metrics_paths] + rows: list[list[Any]] = [[] for _ in metrics_paths] for result in results: for metric, row in zip(metrics_paths, rows): diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py index e699321a4d23e..92ec95a9e2f58 100644 --- a/src/lightning/pytorch/loops/fetchers.py +++ b/src/lightning/pytorch/loops/fetchers.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterator, List, Optional +from collections.abc import Iterator +from typing import Any, Optional from typing_extensions import override @@ -97,7 +98,7 @@ def __init__(self, prefetch_batches: int = 1) -> None: if prefetch_batches < 0: raise ValueError("`prefetch_batches` should at least be 0.") self.prefetch_batches = prefetch_batches - self.batches: List[Any] = [] + self.batches: list[Any] = [] @override def __iter__(self) -> "_PrefetchDataFetcher": diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index eb30e32757c9a..31d6724a043a3 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass +from typing import Any, Optional, Union import torch from typing_extensions import override @@ -45,6 +46,15 @@ log = logging.getLogger(__name__) +@dataclass +class RestartStage: + NONE = "none" + RESTARTED_ON_EPOCH_START = "restarted_on_epoch_start" + RESTARTED_MID_EPOCH = "restarted_mid_epoch" + RESTARTED_ON_EPOCH_END = "restarted_on_epoch_end" + RESUMED_ON_EPOCH_END = "resumed_on_epoch_end" + + class _FitLoop(_Loop): """This loop is the top-level loop where training starts. @@ -94,9 +104,10 @@ def __init__( self._data_source = _DataLoaderSource(None, "train_dataloader") self._combined_loader: Optional[CombinedLoader] = None - self._combined_loader_states_to_load: List[Dict[str, Any]] = [] + self._combined_loader_states_to_load: list[dict[str, Any]] = [] self._data_fetcher: Optional[_DataFetcher] = None self._last_train_dl_reload_epoch = float("-inf") + self._restart_stage = RestartStage.NONE @property def total_batch_idx(self) -> int: @@ -204,9 +215,10 @@ def run(self) -> None: self.on_advance_start() self.advance() self.on_advance_end() - self._restarting = False except StopIteration: break + finally: + self.on_iteration_done() self._restarting = False self.on_run_end() @@ -302,14 +314,92 @@ def setup_data(self) -> None: category=PossibleUserWarning, ) + @property + def restarted_on_epoch_start(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_START + + @property + def restarted_mid_epoch(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_MID_EPOCH + + @property + def restarted_on_epoch_end(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_END + + @property + def resumed_on_epoch_end(self) -> bool: + # This case happens when restarting from last without validation at + # the end of epoch. In this case self.restarting is False. + return self._restart_stage == RestartStage.RESUMED_ON_EPOCH_END + + def update_restart_stage(self) -> None: + if ( + self.restarting + and self.epoch_progress.total.started == self.epoch_progress.total.ready - 1 + and self.epoch_progress.total.processed == self.epoch_progress.total.started + and self.epoch_progress.total.completed == self.epoch_progress.total.processed + ): + self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_START + elif ( + self.restarting + and self.epoch_progress.total.started == self.epoch_progress.total.ready + and self.epoch_progress.total.processed == self.epoch_progress.total.started - 1 + and self.epoch_progress.total.completed == self.epoch_progress.total.processed + ): + self._restart_stage = RestartStage.RESTARTED_MID_EPOCH + elif ( + self.restarting + and self.epoch_progress.total.started == self.epoch_progress.total.ready + and self.epoch_progress.total.processed == self.epoch_progress.total.started + and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1 + ): + self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_END + elif ( + self._loaded_from_state_dict + and self.epoch_progress.total.started == self.epoch_progress.total.ready + and self.epoch_progress.total.processed == self.epoch_progress.total.started + and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1 + ): + self._restart_stage = RestartStage.RESUMED_ON_EPOCH_END + else: + self._restart_stage = RestartStage.NONE + + self.epoch_loop.update_restart_stage() + + def reset_restart_stage(self) -> None: + self._restart_stage = RestartStage.NONE + def reset(self) -> None: """Resets the internal state of this loop.""" assert self.trainer.model is not None torch.set_grad_enabled(True) - if self.restarting: + self.update_restart_stage() + + if self.restarted_on_epoch_start: self.epoch_progress.reset_on_restart() + if self.resumed_on_epoch_end: + # when restarting from last without validation at end of epoch, + # self.restarting is False but it's still resuming + self.epoch_progress.increment_completed() + + if ( + self.epoch_loop.restarted_on_train_batch_end + and self.restarted_mid_epoch + and self.epoch_loop.batch_progress.is_last_batch + ): + self.epoch_progress.increment_processed() + self.epoch_progress.increment_completed() + + if ( + self.epoch_loop.restarted_on_train_batch_end + and self.epoch_loop.batch_progress.is_last_batch + and not self.restarted_mid_epoch + and not self.epoch_loop.val_loop.batch_progress.is_last_batch + ): + self.epoch_progress.increment_completed() + def on_run_start(self) -> None: """Calls the ``on_train_start`` hook.""" # update the current_epoch in-case of checkpoint reload @@ -340,12 +430,14 @@ def on_advance_start(self) -> None: for i, dl in enumerate(self._combined_loader.flattened): _set_sampler_epoch(dl, self.epoch_progress.current.processed) - self.epoch_progress.increment_ready() + if not self.restarted_mid_epoch and not self.restarted_on_epoch_end: + if not self.restarted_on_epoch_start: + self.epoch_progress.increment_ready() - call._call_callback_hooks(trainer, "on_train_epoch_start") - call._call_lightning_module_hook(trainer, "on_train_epoch_start") + call._call_callback_hooks(trainer, "on_train_epoch_start") + call._call_lightning_module_hook(trainer, "on_train_epoch_start") - self.epoch_progress.increment_started() + self.epoch_progress.increment_started() def advance(self) -> None: """Runs one whole epoch.""" @@ -379,8 +471,7 @@ def on_advance_end(self) -> None: trainer._logger_connector.on_epoch_end() - if self.epoch_loop._num_ready_batches_reached(): - # if we are restarting and the above condition holds, it's because we are reloading an epoch-end checkpoint. + if not self.restarting and self.epoch_loop._num_ready_batches_reached(): # since metric-based schedulers require access to metrics and those are not currently saved in the # checkpoint, the plateau schedulers shouldn't be updated self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting) @@ -413,14 +504,14 @@ def teardown(self) -> None: self.epoch_loop.teardown() @override - def on_save_checkpoint(self) -> Dict: + def on_save_checkpoint(self) -> dict: state_dict = super().on_save_checkpoint() if self._combined_loader is not None and (loader_states := self._combined_loader._state_dicts()): state_dict["combined_loader"] = loader_states return state_dict @override - def on_load_checkpoint(self, state_dict: Dict) -> None: + def on_load_checkpoint(self, state_dict: dict) -> None: self._combined_loader_states_to_load = state_dict.get("combined_loader", []) super().on_load_checkpoint(state_dict) diff --git a/src/lightning/pytorch/loops/loop.py b/src/lightning/pytorch/loops/loop.py index 56d520800c447..daad309cd75d4 100644 --- a/src/lightning/pytorch/loops/loop.py +++ b/src/lightning/pytorch/loops/loop.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional +from typing import Optional import lightning.pytorch as pl from lightning.pytorch.loops.progress import _BaseProgress @@ -22,6 +22,7 @@ class _Loop: def __init__(self, trainer: "pl.Trainer") -> None: self._restarting = False + self._loaded_from_state_dict = False self.trainer = trainer @property @@ -37,7 +38,10 @@ def restarting(self, restarting: bool) -> None: if isinstance(loop, _Loop): loop.restarting = restarting - def on_save_checkpoint(self) -> Dict: + def reset_restart_stage(self) -> None: + pass + + def on_save_checkpoint(self) -> dict: """Called when saving a model checkpoint, use to persist loop state. Returns: @@ -46,10 +50,10 @@ def on_save_checkpoint(self) -> Dict: """ return {} - def on_load_checkpoint(self, state_dict: Dict) -> None: + def on_load_checkpoint(self, state_dict: dict) -> None: """Called when loading a model checkpoint, use to reload loop state.""" - def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict: + def state_dict(self, destination: Optional[dict] = None, prefix: str = "") -> dict: """The state dict is determined by the state and progress of this loop and all its children. Args: @@ -73,7 +77,7 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Di def load_state_dict( self, - state_dict: Dict, + state_dict: dict, prefix: str = "", ) -> None: """Loads the state of this loop and all its children.""" @@ -82,8 +86,9 @@ def load_state_dict( if isinstance(v, _Loop): v.load_state_dict(state_dict.copy(), prefix + k + ".") self.restarting = True + self._loaded_from_state_dict = True - def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None: + def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None: for k, v in self.__dict__.items(): key = prefix + k if key not in state_dict: @@ -93,3 +98,8 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None: v.load_state_dict(state_dict[key]) if prefix + "state_dict" in state_dict: # compatibility with old checkpoints self.on_load_checkpoint(state_dict[prefix + "state_dict"]) + + def on_iteration_done(self) -> None: + self._restarting = False + self._loaded_from_state_dict = False + self.reset_restart_stage() diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py index 2ce6acab11a37..e19b5761c4d4b 100644 --- a/src/lightning/pytorch/loops/optimization/automatic.py +++ b/src/lightning/pytorch/loops/optimization/automatic.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections import OrderedDict +from collections.abc import Mapping from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Dict, Mapping, Optional, OrderedDict +from typing import Any, Callable, Optional import torch from torch import Tensor @@ -46,7 +48,7 @@ class ClosureResult(OutputResult): closure_loss: Optional[Tensor] loss: Optional[Tensor] = field(init=False, default=None) - extra: Dict[str, Any] = field(default_factory=dict) + extra: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: self._clone_loss() @@ -83,7 +85,7 @@ def from_training_step_output(cls, training_step_output: STEP_OUTPUT, normalize: return cls(closure_loss, extra=extra) @override - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: return {"loss": self.loss, **self.extra} @@ -145,7 +147,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: return self._result.loss -_OUTPUTS_TYPE = Dict[str, Any] +_OUTPUTS_TYPE = dict[str, Any] class _AutomaticOptimization(_Loop): diff --git a/src/lightning/pytorch/loops/optimization/closure.py b/src/lightning/pytorch/loops/optimization/closure.py index 4b550166b721e..e45262a067f52 100644 --- a/src/lightning/pytorch/loops/optimization/closure.py +++ b/src/lightning/pytorch/loops/optimization/closure.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Generic, Optional, TypeVar +from typing import Any, Generic, Optional, TypeVar from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -22,7 +22,7 @@ @dataclass class OutputResult: - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: raise NotImplementedError diff --git a/src/lightning/pytorch/loops/optimization/manual.py b/src/lightning/pytorch/loops/optimization/manual.py index d8a4f1968c3b8..e1aabcbf42976 100644 --- a/src/lightning/pytorch/loops/optimization/manual.py +++ b/src/lightning/pytorch/loops/optimization/manual.py @@ -14,7 +14,7 @@ from collections import OrderedDict from contextlib import suppress from dataclasses import dataclass, field -from typing import Any, Dict +from typing import Any from torch import Tensor from typing_extensions import override @@ -40,7 +40,7 @@ class ManualResult(OutputResult): """ - extra: Dict[str, Any] = field(default_factory=dict) + extra: dict[str, Any] = field(default_factory=dict) @classmethod def from_training_step_output(cls, training_step_output: STEP_OUTPUT) -> "ManualResult": @@ -61,11 +61,11 @@ def from_training_step_output(cls, training_step_output: STEP_OUTPUT) -> "Manual return cls(extra=extra) @override - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: return self.extra -_OUTPUTS_TYPE = Dict[str, Any] +_OUTPUTS_TYPE = dict[str, Any] class _ManualOptimization(_Loop): diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index 9002e6280ffc6..dcfd873a28b4b 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict -from typing import Any, Iterator, List, Optional, Union +from collections.abc import Iterator +from typing import Any, Optional, Union import torch from lightning_utilities import WarningCache @@ -50,17 +51,17 @@ def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None: super().__init__(trainer) self.inference_mode = inference_mode # dataloaders x batches x samples. used by PredictionWriter - self.epoch_batch_indices: List[List[List[int]]] = [] - self.current_batch_indices: List[int] = [] # used by PredictionWriter + self.epoch_batch_indices: list[list[list[int]]] = [] + self.current_batch_indices: list[int] = [] # used by PredictionWriter self.batch_progress = _Progress() # across dataloaders - self.max_batches: List[Union[int, float]] = [] + self.max_batches: list[Union[int, float]] = [] self._warning_cache = WarningCache() self._data_source = _DataLoaderSource(None, "predict_dataloader") self._combined_loader: Optional[CombinedLoader] = None self._data_fetcher: Optional[_DataFetcher] = None self._results = None # for `trainer._results` access - self._predictions: List[List[Any]] = [] # dataloaders x batches + self._predictions: list[list[Any]] = [] # dataloaders x batches self._return_predictions = False self._module_mode = _ModuleMode() @@ -82,7 +83,7 @@ def return_predictions(self, return_predictions: Optional[bool] = None) -> None: self._return_predictions = return_supported if return_predictions is None else return_predictions @property - def predictions(self) -> List[Any]: + def predictions(self) -> list[Any]: """The cached predictions.""" if self._predictions == []: return self._predictions @@ -232,8 +233,9 @@ def _predict_step( self.batch_progress.increment_ready() - if not using_dataloader_iter: - any_on_epoch = self._store_data_for_prediction_writer(batch_idx, dataloader_idx) + any_on_epoch = ( + self._store_data_for_prediction_writer(batch_idx, dataloader_idx) if not using_dataloader_iter else False + ) # the `_step` methods don't take a batch_idx when `dataloader_iter` is used, but all other hooks still do, # so we need different kwargs @@ -297,7 +299,7 @@ def _build_step_args_from_hook_kwargs(self, hook_kwargs: OrderedDict, step_hook_ kwargs.pop("batch_idx", None) return tuple(kwargs.values()) - def _get_batch_indices(self, dataloader: object) -> List[List[int]]: # batches x samples + def _get_batch_indices(self, dataloader: object) -> list[list[int]]: # batches x samples """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our :class:`~lightning.pytorch.overrides.distributed._IndexBatchSamplerWrapper`.""" batch_sampler = getattr(dataloader, "batch_sampler", None) diff --git a/src/lightning/pytorch/loops/progress.py b/src/lightning/pytorch/loops/progress.py index 3d34653122329..42e5de642aa32 100644 --- a/src/lightning/pytorch/loops/progress.py +++ b/src/lightning/pytorch/loops/progress.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import asdict, dataclass, field -from typing import Type from typing_extensions import override @@ -68,6 +67,10 @@ def reset_on_restart(self) -> None: """ self.ready = self.completed + def increment_by(self, n: int) -> None: + self.ready += n + self.completed += n + @dataclass class _StartedTracker(_ReadyCompletedTracker): @@ -94,6 +97,11 @@ def reset_on_restart(self) -> None: super().reset_on_restart() self.started = self.completed + @override + def increment_by(self, n: int) -> None: + super().increment_by(n) + self.started += n + @dataclass class _ProcessedTracker(_StartedTracker): @@ -121,6 +129,11 @@ def reset_on_restart(self) -> None: super().reset_on_restart() self.processed = self.completed + @override + def increment_by(self, n: int) -> None: + super().increment_by(n) + self.processed += n + @dataclass class _Progress(_BaseProgress): @@ -160,7 +173,7 @@ def increment_completed(self) -> None: self.current.completed += 1 @classmethod - def from_defaults(cls, tracker_cls: Type[_ReadyCompletedTracker], **kwargs: int) -> "_Progress": + def from_defaults(cls, tracker_cls: type[_ReadyCompletedTracker], **kwargs: int) -> "_Progress": """Utility function to easily create an instance from keyword arguments to both ``Tracker``s.""" return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs)) @@ -175,6 +188,10 @@ def reset_on_run(self) -> None: def reset_on_restart(self) -> None: self.current.reset_on_restart() + def increment_by(self, n: int) -> None: + self.total.increment_by(n) + self.current.increment_by(n) + @override def load_state_dict(self, state_dict: dict) -> None: self.total.load_state_dict(state_dict["total"]) @@ -206,6 +223,10 @@ def reset_on_run(self) -> None: super().reset_on_run() self.is_last_batch = False + def increment_by(self, n: int, is_last_batch: bool = False) -> None: + super().increment_by(n) + self.is_last_batch = is_last_batch + @override def load_state_dict(self, state_dict: dict) -> None: super().load_state_dict(state_dict) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 9e36ee65176c8..7cdf7888bbfe2 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -13,7 +13,8 @@ # limitations under the License. import math from collections import OrderedDict -from typing import Any, Dict, Optional, Union +from dataclasses import dataclass +from typing import Any, Optional, Union from typing_extensions import override @@ -37,6 +38,13 @@ _BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] +@dataclass +class RestartStage: + NONE = "none" + RESTARTED_ON_TRAIN_BATCH_END = "restarted_on_train_batch_end" + RESTARTED_ON_LAST = "restarted_on_last" + + class _TrainingEpochLoop(loops._Loop): """Iterates over all batches in the dataloader (one epoch) that the user returns in their :meth:`~lightning.pytorch.core.LightningModule.train_dataloader` method. @@ -81,6 +89,8 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s self._results = _ResultCollection(training=True) self._warning_cache = WarningCache() self._batches_that_stepped: int = 0 + self._restart_stage = RestartStage.NONE + self._skip_next_val = False @property def total_batch_idx(self) -> int: @@ -139,13 +149,63 @@ def run(self, data_fetcher: _DataFetcher) -> None: try: self.advance(data_fetcher) self.on_advance_end(data_fetcher) - self._restarting = False except StopIteration: break - self._restarting = False + finally: + self.on_iteration_done() + + @property + def restarted_on_train_batch_end(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_ON_TRAIN_BATCH_END + + @property + def restarted_on_last(self) -> bool: + return self._restart_stage == RestartStage.RESTARTED_ON_LAST + + def update_restart_stage(self) -> None: + if ( + self.restarting + and self.batch_progress.total.started == self.batch_progress.total.ready + and self.batch_progress.total.processed == self.batch_progress.total.started + and self.batch_progress.total.completed == self.batch_progress.total.processed - 1 + ): + self._restart_stage = RestartStage.RESTARTED_ON_TRAIN_BATCH_END + elif ( + self.restarting + and self.batch_progress.total.started == self.batch_progress.total.ready + and self.batch_progress.total.processed == self.batch_progress.total.started + and self.batch_progress.total.completed == self.batch_progress.total.processed + ): + self._restart_stage = RestartStage.RESTARTED_ON_LAST + else: + self._restart_stage = RestartStage.NONE + + self.val_loop.update_restart_stage() + + def reset_restart_stage(self) -> None: + self._restart_stage = RestartStage.NONE def reset(self) -> None: """Resets the internal state of the loop for a new run.""" + if ( + self.restarting + and not self._should_accumulate() + and (self.restarted_on_train_batch_end or not self.restarted_on_last) + ): + # batches_that_stepped is never set prior to saving a checkpoint, even when saving + # happens on_validation_end + # we could set it in the checkpoint but we prefer to keep checkpoints backward compatible + self._batches_that_stepped += 1 + + if self.restarted_on_train_batch_end: + self.batch_progress.increment_completed() + # handle situation in which save happened on_train_batch_end and epoch is at end + if self.batch_progress.current.completed >= self.trainer.num_training_batches: + self.batch_progress.reset_on_run() + self.scheduler_progress.reset_on_run() + self.automatic_optimization.optim_progress.reset_on_run() + self.val_loop.batch_progress.total.reset() + if self.restarting: self.batch_progress.reset_on_restart() self.scheduler_progress.reset_on_restart() @@ -197,8 +257,18 @@ def advance(self, data_fetcher: _DataFetcher) -> None: """ if self.restarting and self._should_check_val_fx(data_fetcher): - # skip training and run validation in `on_advance_end` - return + if self.val_loop.restarted_mid_evaluation: + # Go back and finish running validation + return + + if self.restarted_on_last: + # Avoid running validation again if we saved on last + self._skip_next_val = True + return + + # fast forward progress counters to end of validation + self.val_loop.increment_progress_to_evaluation_end() + # we are going to train first so the val loop does not need to restart self.val_loop.restarting = False @@ -282,6 +352,11 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None: # VALIDATE IF NEEDED # ----------------------------------------- should_check_val = self._should_check_val_fx(data_fetcher) + + if self._skip_next_val: + should_check_val = False + self._skip_next_val = False + if should_check_val: # this needs to be set so the correct `trainer._active_loop` is picked self.trainer.validating = True @@ -315,13 +390,13 @@ def teardown(self) -> None: self.val_loop.teardown() @override - def on_save_checkpoint(self) -> Dict: + def on_save_checkpoint(self) -> dict: state_dict = super().on_save_checkpoint() state_dict["_batches_that_stepped"] = self._batches_that_stepped return state_dict @override - def on_load_checkpoint(self, state_dict: Dict) -> None: + def on_load_checkpoint(self, state_dict: dict) -> None: self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0) def _accumulated_batches_reached(self) -> bool: diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index 99ea5c4254d62..2aaf877c8913d 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from contextlib import contextmanager -from typing import Any, Callable, ContextManager, Generator, Optional, Tuple, Type +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager +from typing import Any, Callable, Optional import torch import torch.distributed as dist @@ -52,7 +53,7 @@ def _parse_loop_limits( min_epochs: Optional[int], max_epochs: Optional[int], trainer: "pl.Trainer", -) -> Tuple[int, int]: +) -> tuple[int, int]: """This utility computes the default values for the minimum and maximum number of steps and epochs given the values the user has selected. @@ -159,7 +160,7 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any: raise TypeError(f"`{type(self).__name__}` needs to be a Loop.") if not hasattr(self, "inference_mode"): raise TypeError(f"`{type(self).__name__}.inference_mode` needs to be defined") - context_manager: Type[ContextManager] + context_manager: type[AbstractContextManager] if _distributed_is_initialized() and dist.get_backend() == "gloo": # gloo backend does not work properly. # https://github.com/Lightning-AI/lightning/pull/12715/files#r854569110 @@ -181,7 +182,7 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any: def _verify_dataloader_idx_requirement( - hooks: Tuple[str, ...], is_expected: bool, stage: RunningStage, pl_module: "pl.LightningModule" + hooks: tuple[str, ...], is_expected: bool, stage: RunningStage, pl_module: "pl.LightningModule" ) -> None: for hook in hooks: fx = getattr(pl_module, hook) diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py index e4b65285538f1..196008b7ed29f 100644 --- a/src/lightning/pytorch/overrides/distributed.py +++ b/src/lightning/pytorch/overrides/distributed.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sized, Union, cast +from collections.abc import Iterable, Iterator, Sized +from typing import Any, Callable, Optional, Union, cast import torch from torch import Tensor @@ -27,7 +28,7 @@ def _find_tensors( obj: Union[Tensor, list, tuple, dict, Any], -) -> Union[List[Tensor], itertools.chain]: # pragma: no-cover +) -> Union[list[Tensor], itertools.chain]: # pragma: no-cover """Recursively find all tensors contained in the specified object.""" if isinstance(obj, Tensor): return [obj] @@ -201,7 +202,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: assert self.num_samples >= 1 or self.total_size == 0 @override - def __iter__(self) -> Iterator[List[int]]: + def __iter__(self) -> Iterator[list[int]]: if not isinstance(self.dataset, Sized): raise TypeError("The given dataset must implement the `__len__` method.") if self.shuffle: @@ -238,7 +239,7 @@ class _IndexBatchSamplerWrapper: def __init__(self, batch_sampler: _SizedIterable) -> None: # do not call super().__init__() on purpose - self.seen_batch_indices: List[List[int]] = [] + self.seen_batch_indices: list[list[int]] = [] self.__dict__ = { k: v @@ -246,9 +247,9 @@ def __init__(self, batch_sampler: _SizedIterable) -> None: if k not in ("__next__", "__iter__", "__len__", "__getstate__") } self._batch_sampler = batch_sampler - self._iterator: Optional[Iterator[List[int]]] = None + self._iterator: Optional[Iterator[list[int]]] = None - def __next__(self) -> List[int]: + def __next__(self) -> list[int]: assert self._iterator is not None batch = next(self._iterator) self.seen_batch_indices.append(batch) @@ -262,7 +263,7 @@ def __iter__(self) -> Self: def __len__(self) -> int: return len(self._batch_sampler) - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() state["_iterator"] = None # cannot pickle 'generator' object return state diff --git a/src/lightning/pytorch/plugins/io/wrapper.py b/src/lightning/pytorch/plugins/io/wrapper.py index 6e918b836e320..548bc1fb15cac 100644 --- a/src/lightning/pytorch/plugins/io/wrapper.py +++ b/src/lightning/pytorch/plugins/io/wrapper.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from typing_extensions import override @@ -66,7 +66,7 @@ def remove_checkpoint(self, *args: Any, **kwargs: Any) -> None: self.checkpoint_io.remove_checkpoint(*args, **kwargs) @override - def load_checkpoint(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: + def load_checkpoint(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """Uses the base ``checkpoint_io`` to load the checkpoint.""" assert self.checkpoint_io is not None return self.checkpoint_io.load_checkpoint(*args, **kwargs) diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index e63ccd6912b63..75e792af46b90 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -9,8 +9,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Callable, Dict, Generator, Literal, Optional, Union +from typing import Any, Callable, Literal, Optional, Union import torch from torch import Tensor @@ -121,12 +122,12 @@ def forward_context(self) -> Generator[None, None, None]: yield @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() return {} @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) diff --git a/src/lightning/pytorch/plugins/precision/bitsandbytes.py b/src/lightning/pytorch/plugins/precision/bitsandbytes.py index 62acc7bf77c8d..3a2daa828bc3c 100644 --- a/src/lightning/pytorch/plugins/precision/bitsandbytes.py +++ b/src/lightning/pytorch/plugins/precision/bitsandbytes.py @@ -16,7 +16,7 @@ class BitsandbytesPrecision(Precision, FabricBNBPrecision): - """Plugin for quantizing weights with `bitsandbytes `__. + """Plugin for quantizing weights with `bitsandbytes `__. .. warning:: This is an :ref:`experimental ` feature. diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index e1e90281cf3af..9225e3bb9e7be 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union +from contextlib import AbstractContextManager, nullcontext +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -80,13 +80,13 @@ def convert_input(self, data: Any) -> Any: return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: if "true" not in self.precision: return nullcontext() return _DtypeContextManager(self._desired_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index 20f493bb7b2e2..efa1aa008a35e 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import contextmanager -from typing import Any, ContextManager, Generator, Literal +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager +from typing import Any, Literal import torch import torch.nn as nn @@ -37,11 +38,11 @@ def convert_module(self, module: nn.Module) -> nn.Module: return module.double() @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(torch.float64) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index e6c684967ed40..f3bab3e915e91 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Optional +from contextlib import AbstractContextManager +from typing import TYPE_CHECKING, Any, Callable, Optional import torch from lightning_utilities import apply_to_collection from torch import Tensor +from torch.nn import Module from typing_extensions import get_args, override import lightning.pytorch as pl @@ -72,6 +74,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca } self._desired_input_dtype = precision_to_type[self.precision] + @override + def convert_module(self, module: Module) -> Module: + if "true" in self.precision: + return module.to(dtype=self._desired_input_dtype) + return module + @override def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ @@ -109,15 +117,15 @@ def mixed_precision_config(self) -> "TorchMixedPrecision": ) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32) @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: if "mixed" in self.precision: return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)) return _DtypeContextManager(self._desired_input_dtype) @@ -166,12 +174,12 @@ def optimizer_step( # type: ignore[override] return closure_result @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() return {} @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) diff --git a/src/lightning/pytorch/plugins/precision/half.py b/src/lightning/pytorch/plugins/precision/half.py index 22dc29b580b95..fe9deb44c3653 100644 --- a/src/lightning/pytorch/plugins/precision/half.py +++ b/src/lightning/pytorch/plugins/precision/half.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import contextmanager -from typing import Any, ContextManager, Generator, Literal +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager +from typing import Any, Literal import torch from lightning_utilities import apply_to_collection @@ -43,11 +44,11 @@ def convert_module(self, module: Module) -> Module: return module.to(dtype=self._desired_input_dtype) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/pytorch/plugins/precision/precision.py b/src/lightning/pytorch/plugins/precision/precision.py index 51bdddb18f814..327fb2d4f5a27 100644 --- a/src/lightning/pytorch/plugins/precision/precision.py +++ b/src/lightning/pytorch/plugins/precision/precision.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib +from collections.abc import Generator from functools import partial -from typing import Any, Callable, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch from torch import Tensor @@ -37,8 +38,8 @@ class Precision(FabricPrecision, CheckpointHooks): """ def connect( - self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any] - ) -> Tuple[Module, List[Optimizer], List[Any]]: + self, model: Module, optimizers: list[Optimizer], lr_schedulers: list[Any] + ) -> tuple[Module, list[Optimizer], list[Any]]: """Connects this plugin to the accelerator and the training process.""" return model, optimizers, lr_schedulers diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py index 467b47124eb60..41681fbd239f3 100644 --- a/src/lightning/pytorch/profilers/advanced.py +++ b/src/lightning/pytorch/profilers/advanced.py @@ -20,7 +20,7 @@ import pstats import tempfile from pathlib import Path -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Union from typing_extensions import override @@ -66,7 +66,7 @@ def __init__( If you attempt to stop recording an action which was never started. """ super().__init__(dirpath=dirpath, filename=filename) - self.profiled_actions: Dict[str, cProfile.Profile] = {} + self.profiled_actions: dict[str, cProfile.Profile] = {} self.line_count_restriction = line_count_restriction self.dump_stats = dump_stats @@ -89,9 +89,10 @@ def _dump_stats(self, action_name: str, profile: cProfile.Profile) -> None: dst_fs = get_filesystem(dst_filepath) dst_fs.mkdirs(self.dirpath, exist_ok=True) # temporarily save to local since pstats can only dump into a local file - with tempfile.TemporaryDirectory( - prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd() - ) as tmp_dir, dst_fs.open(dst_filepath, "wb") as dst_file: + with ( + tempfile.TemporaryDirectory(prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd()) as tmp_dir, + dst_fs.open(dst_filepath, "wb") as dst_file, + ): src_filepath = os.path.join(tmp_dir, "tmp.prof") profile.dump_stats(src_filepath) src_fs = get_filesystem(src_filepath) @@ -115,7 +116,7 @@ def teardown(self, stage: Optional[str]) -> None: super().teardown(stage=stage) self.profiled_actions = {} - def __reduce__(self) -> Tuple: + def __reduce__(self) -> tuple: # avoids `TypeError: cannot pickle 'cProfile.Profile' object` return ( self.__class__, diff --git a/src/lightning/pytorch/profilers/profiler.py b/src/lightning/pytorch/profilers/profiler.py index fb448321575a9..a09b703e606b8 100644 --- a/src/lightning/pytorch/profilers/profiler.py +++ b/src/lightning/pytorch/profilers/profiler.py @@ -16,9 +16,10 @@ import logging import os from abc import ABC, abstractmethod +from collections.abc import Generator from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Dict, Generator, Optional, TextIO, Union +from typing import Any, Callable, Optional, TextIO, Union from lightning.fabric.utilities.cloud_io import get_filesystem @@ -115,7 +116,7 @@ def describe(self) -> None: self._output_file.flush() self.teardown(stage=self._stage) - def _stats_to_str(self, stats: Dict[str, str]) -> str: + def _stats_to_str(self, stats: dict[str, str]) -> str: stage = f"{self._stage.upper()} " if self._stage is not None else "" output = [stage + "Profiler Report"] for action, value in stats.items(): diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py index a26b3d321d2e0..e264d5154feba 100644 --- a/src/lightning/pytorch/profilers/pytorch.py +++ b/src/lightning/pytorch/profilers/pytorch.py @@ -16,9 +16,10 @@ import inspect import logging import os +from contextlib import AbstractContextManager from functools import lru_cache, partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from torch import Tensor, nn @@ -65,8 +66,8 @@ class RegisterRecordFunction: def __init__(self, model: nn.Module) -> None: self._model = model - self._records: Dict[str, record_function] = {} - self._handles: Dict[str, List[RemovableHandle]] = {} + self._records: dict[str, record_function] = {} + self._handles: dict[str, list[RemovableHandle]] = {} def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor: # Add [pl][module] in name for pytorch profiler to recognize @@ -239,7 +240,7 @@ def __init__( row_limit: int = 20, sort_by_key: Optional[str] = None, record_module_names: bool = True, - table_kwargs: Optional[Dict[str, Any]] = None, + table_kwargs: Optional[dict[str, Any]] = None, **profiler_kwargs: Any, ) -> None: r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of @@ -305,8 +306,8 @@ def __init__( self.function_events: Optional[EventList] = None self._lightning_module: Optional[LightningModule] = None # set by ProfilerConnector self._register: Optional[RegisterRecordFunction] = None - self._parent_profiler: Optional[ContextManager] = None - self._recording_map: Dict[str, record_function] = {} + self._parent_profiler: Optional[AbstractContextManager] = None + self._recording_map: dict[str, record_function] = {} self._start_action_name: Optional[str] = None self._schedule: Optional[ScheduleWrapper] = None @@ -400,8 +401,8 @@ def _default_schedule() -> Optional[Callable]: return torch.profiler.schedule(wait=1, warmup=1, active=3) return None - def _default_activities(self) -> List["ProfilerActivity"]: - activities: List[ProfilerActivity] = [] + def _default_activities(self) -> list["ProfilerActivity"]: + activities: list[ProfilerActivity] = [] if not _KINETO_AVAILABLE: return activities if _TORCH_GREATER_EQUAL_2_4: @@ -530,7 +531,7 @@ def _create_profilers(self) -> None: torch.profiler.profile if _KINETO_AVAILABLE else torch.autograd.profiler.profile ) - def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER: + def _create_profiler(self, profiler: type[_PROFILER]) -> _PROFILER: init_parameters = inspect.signature(profiler.__init__).parameters kwargs = {k: v for k, v in self._profiler_kwargs.items() if k in init_parameters} return profiler(**kwargs) diff --git a/src/lightning/pytorch/profilers/simple.py b/src/lightning/pytorch/profilers/simple.py index eef7b12892faa..8a53965e3f487 100644 --- a/src/lightning/pytorch/profilers/simple.py +++ b/src/lightning/pytorch/profilers/simple.py @@ -18,7 +18,7 @@ import time from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union import torch from typing_extensions import override @@ -27,10 +27,10 @@ log = logging.getLogger(__name__) -_TABLE_ROW_EXTENDED = Tuple[str, float, int, float, float] -_TABLE_DATA_EXTENDED = List[_TABLE_ROW_EXTENDED] -_TABLE_ROW = Tuple[str, float, float] -_TABLE_DATA = List[_TABLE_ROW] +_TABLE_ROW_EXTENDED = tuple[str, float, int, float, float] +_TABLE_DATA_EXTENDED = list[_TABLE_ROW_EXTENDED] +_TABLE_ROW = tuple[str, float, float] +_TABLE_DATA = list[_TABLE_ROW] class SimpleProfiler(Profiler): @@ -61,8 +61,8 @@ def __init__( if you attempt to stop recording an action which was never started. """ super().__init__(dirpath=dirpath, filename=filename) - self.current_actions: Dict[str, float] = {} - self.recorded_durations: Dict = defaultdict(list) + self.current_actions: dict[str, float] = {} + self.recorded_durations: dict = defaultdict(list) self.extended = extended self.start_time = time.monotonic() @@ -81,7 +81,7 @@ def stop(self, action_name: str) -> None: duration = end_time - start_time self.recorded_durations[action_name].append(duration) - def _make_report_extended(self) -> Tuple[_TABLE_DATA_EXTENDED, float, float]: + def _make_report_extended(self) -> tuple[_TABLE_DATA_EXTENDED, float, float]: total_duration = time.monotonic() - self.start_time report = [] diff --git a/src/lightning/pytorch/profilers/xla.py b/src/lightning/pytorch/profilers/xla.py index a85f3a1295e78..3e810fbc4f096 100644 --- a/src/lightning/pytorch/profilers/xla.py +++ b/src/lightning/pytorch/profilers/xla.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict from typing_extensions import override @@ -45,8 +44,8 @@ def __init__(self, port: int = 9012) -> None: raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__(dirpath=None, filename=None) self.port = port - self._recording_map: Dict = {} - self._step_recoding_map: Dict = {} + self._recording_map: dict = {} + self._step_recoding_map: dict = {} self._start_trace: bool = False @override diff --git a/src/lightning/pytorch/serve/servable_module.py b/src/lightning/pytorch/serve/servable_module.py index f715f4b3cad9d..ed7a8a987898b 100644 --- a/src/lightning/pytorch/serve/servable_module.py +++ b/src/lightning/pytorch/serve/servable_module.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable import torch from torch import Tensor @@ -56,11 +56,11 @@ def configure_response(self): """ @abstractmethod - def configure_payload(self) -> Dict[str, Any]: + def configure_payload(self) -> dict[str, Any]: """Returns a request payload as a dictionary.""" @abstractmethod - def configure_serialization(self) -> Tuple[Dict[str, Callable], Dict[str, Callable]]: + def configure_serialization(self) -> tuple[dict[str, Callable], dict[str, Callable]]: """Returns a tuple of dictionaries. The first dictionary contains the name of the ``serve_step`` input variables name as its keys @@ -72,7 +72,7 @@ def configure_serialization(self) -> Tuple[Dict[str, Callable], Dict[str, Callab """ @abstractmethod - def serve_step(self, *args: Tensor, **kwargs: Tensor) -> Dict[str, Tensor]: + def serve_step(self, *args: Tensor, **kwargs: Tensor) -> dict[str, Tensor]: r"""Returns the predictions of your model as a dictionary. .. code-block:: python @@ -90,5 +90,5 @@ def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: """ @abstractmethod - def configure_response(self) -> Dict[str, Any]: + def configure_response(self) -> dict[str, Any]: """Returns a response to validate the server response.""" diff --git a/src/lightning/pytorch/serve/servable_module_validator.py b/src/lightning/pytorch/serve/servable_module_validator.py index 0acab203dedff..dc92625da357d 100644 --- a/src/lightning/pytorch/serve/servable_module_validator.py +++ b/src/lightning/pytorch/serve/servable_module_validator.py @@ -2,7 +2,7 @@ import logging import time from multiprocessing import Process -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal, Optional import requests import torch @@ -136,7 +136,7 @@ def successful(self) -> Optional[bool]: return self.resp.status_code == 200 if self.resp else None @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {"successful": self.successful, "optimization": self.optimization, "server": self.server} @staticmethod @@ -157,7 +157,7 @@ def ping() -> bool: return True @app.post("/serve") - async def serve(payload: dict = Body(...)) -> Dict[str, Any]: + async def serve(payload: dict = Body(...)) -> dict[str, Any]: body = payload["body"] for key, deserializer in deserializers.items(): diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 9031b6ee177f3..fd3f66ef42471 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -14,7 +14,7 @@ import logging from contextlib import nullcontext from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union import torch import torch.distributed @@ -71,7 +71,7 @@ class DDPStrategy(ParallelStrategy): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[Precision] = None, @@ -133,7 +133,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def distributed_sampler_kwargs(self) -> dict[str, Any]: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property @@ -283,7 +283,7 @@ def configure_ddp(self) -> None: self.model = self._setup_model(self.model) self._register_ddp_hooks() - def determine_ddp_device_ids(self) -> Optional[List[int]]: + def determine_ddp_device_ids(self) -> Optional[list[int]]: if self.root_device.type == "cpu": return None return [self.root_device.index] diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 382f8070898f8..e17377d4464b0 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -17,9 +17,11 @@ import os import platform from collections import OrderedDict +from collections.abc import Generator, Mapping from contextlib import contextmanager +from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch from torch.nn import Module @@ -29,6 +31,7 @@ import lightning.pytorch as pl from lightning.fabric.plugins import ClusterEnvironment +from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout from lightning.fabric.strategies import _StrategyRegistry from lightning.fabric.strategies.deepspeed import ( _DEEPSPEED_AVAILABLE, @@ -102,9 +105,9 @@ def __init__( reduce_bucket_size: int = 200_000_000, zero_allow_untested_optimizer: bool = True, logging_batch_size_per_gpu: Union[str, int] = "auto", - config: Optional[Union[_PATH, Dict[str, Any]]] = None, + config: Optional[Union[_PATH, dict[str, Any]]] = None, logging_level: int = logging.WARN, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, loss_scale: float = 0, initial_scale_power: int = 16, @@ -118,6 +121,7 @@ def __init__( load_full_weights: bool = False, precision_plugin: Optional[Precision] = None, process_group_backend: Optional[str] = None, + timeout: Optional[timedelta] = default_pg_timeout, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large billion parameter models. `For more information: https://pytorch- @@ -263,6 +267,7 @@ def __init__( precision_plugin=precision_plugin, process_group_backend=process_group_backend, ) + self._timeout: Optional[timedelta] = timeout self.config = self._load_config(config) if self.config is None: @@ -363,7 +368,9 @@ def _init_deepspeed_distributed(self) -> None: f"MEMBER: {self.global_rank + 1}/{self.world_size}" ) self._process_group_backend = self._get_process_group_backend() - deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port) + deepspeed.init_distributed( + self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout + ) def _set_node_environment_variables(self) -> None: assert self.cluster_environment is not None @@ -380,8 +387,8 @@ def restore_checkpoint_after_setup(self) -> bool: @override def _setup_model_and_optimizers( - self, model: Module, optimizers: List[Optimizer] - ) -> Tuple["deepspeed.DeepSpeedEngine", List[Optimizer]]: + self, model: Module, optimizers: list[Optimizer] + ) -> tuple["deepspeed.DeepSpeedEngine", list[Optimizer]]: """Setup a model and multiple optimizers together. Currently only a single optimizer is supported. @@ -411,10 +418,10 @@ def _setup_model_and_optimizer( model: Module, optimizer: Optional[Optimizer], lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None, - ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]: + ) -> tuple["deepspeed.DeepSpeedEngine", Optimizer]: """Initialize one model and one optimizer with an optional learning rate scheduler. - This calls :func:`deepspeed.initialize` internally. + This calls ``deepspeed.initialize`` internally. """ import deepspeed @@ -452,7 +459,7 @@ def init_deepspeed(self) -> None: else: self._initialize_deepspeed_inference(self.model) - def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig]]: + def _init_optimizers(self) -> tuple[Optimizer, Optional[LRSchedulerConfig]]: assert self.lightning_module is not None optimizers, lr_schedulers = _init_optimizers_and_lr_schedulers(self.lightning_module) if len(optimizers) > 1 or len(lr_schedulers) > 1: @@ -572,7 +579,7 @@ def _initialize_deepspeed_inference(self, model: Module) -> None: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, int]: + def distributed_sampler_kwargs(self) -> dict[str, int]: return {"num_replicas": self.world_size, "rank": self.global_rank} @override @@ -608,7 +615,7 @@ def _multi_device(self) -> bool: return self.num_processes > 1 or self.num_nodes > 1 @override - def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -645,7 +652,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Op self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint, tag="checkpoint") @override - def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: if self.load_full_weights and self.zero_stage_3: # Broadcast to ensure we load from the rank 0 checkpoint # This doesn't have to be the case when using deepspeed sharded checkpointing @@ -708,9 +715,9 @@ def _restore_zero_state(self, ckpt: Mapping[str, Any], strict: bool) -> None: assert self.lightning_module is not None def load(module: torch.nn.Module, prefix: str = "") -> None: - missing_keys: List[str] = [] - unexpected_keys: List[str] = [] - error_msgs: List[str] = [] + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + error_msgs: list[str] = [] state_dict = ckpt["state_dict"] # copy state_dict so _load_from_state_dict can modify it @@ -780,7 +787,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: offload_optimizer_device="nvme", ) - def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]: + def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]: if config is None and self.DEEPSPEED_ENV_VAR in os.environ: rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") config = os.environ[self.DEEPSPEED_ENV_VAR] @@ -841,7 +848,7 @@ def _create_default_config( overlap_events: bool, thread_count: int, **zero_kwargs: Any, - ) -> Dict: + ) -> dict: cfg = { "activation_checkpointing": { "partition_activations": partition_activations, diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index ab6e579c3071f..bfbf99e82934c 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -13,6 +13,7 @@ # limitations under the License. import logging import shutil +from collections.abc import Generator, Mapping from contextlib import contextmanager, nullcontext from datetime import timedelta from pathlib import Path @@ -20,15 +21,8 @@ TYPE_CHECKING, Any, Callable, - Dict, - Generator, - List, Literal, - Mapping, Optional, - Set, - Tuple, - Type, Union, ) @@ -88,7 +82,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy - _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] + _POLICY = Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]] @@ -148,12 +142,12 @@ class FSDPStrategy(ParallelStrategy): """ strategy_name = "fsdp" - _registered_strategies: List[str] = [] + _registered_strategies: list[str] = [] def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[Precision] = None, @@ -162,11 +156,11 @@ def __init__( cpu_offload: Union[bool, "CPUOffload", None] = None, mixed_precision: Optional["MixedPrecision"] = None, auto_wrap_policy: Optional["_POLICY"] = None, - activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, + activation_checkpointing: Optional[Union[type[Module], list[type[Module]]]] = None, activation_checkpointing_policy: Optional["_POLICY"] = None, sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD", state_dict_type: Literal["full", "sharded"] = "full", - device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None, + device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None, **kwargs: Any, ) -> None: super().__init__( @@ -242,7 +236,7 @@ def precision_plugin(self, precision_plugin: Optional[FSDPPrecision]) -> None: @property @override - def distributed_sampler_kwargs(self) -> Dict: + def distributed_sampler_kwargs(self) -> dict: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property @@ -455,7 +449,7 @@ def reduce( return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor - def _determine_device_ids(self) -> List[int]: + def _determine_device_ids(self) -> list[int]: return [self.root_device.index] @override @@ -481,7 +475,7 @@ def teardown(self) -> None: self.accelerator.teardown() @classmethod - def get_registered_strategies(cls) -> List[str]: + def get_registered_strategies(cls) -> list[str]: return cls._registered_strategies @classmethod @@ -505,7 +499,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: cls._registered_strategies.append("fsdp_cpu_offload") @override - def lightning_module_state_dict(self) -> Dict[str, Any]: + def lightning_module_state_dict(self) -> dict[str, Any]: assert self.model is not None if self._state_dict_type == "sharded": state_dict_ctx = _get_sharded_state_dict_context(self.model) @@ -522,7 +516,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr pass @override - def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: + def optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import OptimStateKeyType @@ -551,7 +545,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: @override def save_checkpoint( - self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: if storage_options is not None: raise TypeError( @@ -586,7 +580,7 @@ def save_checkpoint( raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") @override - def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py index 05e3fed561ccb..aa207a527814e 100644 --- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py +++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py @@ -18,7 +18,7 @@ import tempfile from contextlib import suppress from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Union +from typing import Any, Callable, Literal, NamedTuple, Optional, Union import torch import torch.backends.cudnn @@ -80,7 +80,7 @@ def __init__( f"The start method '{self._start_method}' is not available on this platform. Available methods are:" f" {', '.join(mp.get_all_start_methods())}" ) - self.procs: List[mp.Process] = [] + self.procs: list[mp.Process] = [] self._already_fit = False @property @@ -224,7 +224,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra) - def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]: + def get_extra_results(self, trainer: "pl.Trainer") -> dict[str, Any]: """Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To avoid issues with memory sharing, we convert tensors to bytes. @@ -242,7 +242,7 @@ def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]: # send tensors as bytes to avoid issues with memory sharing return {"callback_metrics_bytes": buffer.getvalue()} - def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None: + def update_main_process_results(self, trainer: "pl.Trainer", extra: dict[str, Any]) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we convert bytes back to ``torch.Tensor``. @@ -265,7 +265,7 @@ def kill(self, signum: _SIGNUM) -> None: with suppress(ProcessLookupError): os.kill(proc.pid, signum) - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: state = self.__dict__.copy() state["procs"] = [] # SpawnProcess can't be pickled return state @@ -276,7 +276,7 @@ class _WorkerOutput(NamedTuple): weights_path: Optional[_PATH] trainer_state: TrainerState trainer_results: Any - extra: Dict[str, Any] + extra: dict[str, Any] @dataclass @@ -301,7 +301,7 @@ class _GlobalStateSnapshot: use_deterministic_algorithms: bool use_deterministic_algorithms_warn_only: bool cudnn_benchmark: bool - rng_states: Dict[str, Any] + rng_states: dict[str, Any] @classmethod def capture(cls) -> "_GlobalStateSnapshot": diff --git a/src/lightning/pytorch/strategies/launchers/subprocess_script.py b/src/lightning/pytorch/strategies/launchers/subprocess_script.py index d2035d03d2589..b7ec294c148d5 100644 --- a/src/lightning/pytorch/strategies/launchers/subprocess_script.py +++ b/src/lightning/pytorch/strategies/launchers/subprocess_script.py @@ -14,7 +14,7 @@ import logging import os import subprocess -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Optional from lightning_utilities.core.imports import RequirementCache from typing_extensions import override @@ -77,7 +77,7 @@ def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, self.cluster_environment = cluster_environment self.num_processes = num_processes self.num_nodes = num_nodes - self.procs: List[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher + self.procs: list[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher @property @override diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py index fb45166378c78..82fec205af731 100644 --- a/src/lightning/pytorch/strategies/model_parallel.py +++ b/src/lightning/pytorch/strategies/model_parallel.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import shutil +from collections.abc import Generator, Mapping from contextlib import contextmanager, nullcontext from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only @@ -114,7 +115,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def distributed_sampler_kwargs(self) -> dict[str, Any]: assert self.device_mesh is not None data_parallel_mesh = self.device_mesh["data_parallel"] return {"num_replicas": data_parallel_mesh.size(), "rank": data_parallel_mesh.get_local_rank()} @@ -237,7 +238,7 @@ def reduce( return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor - def _determine_device_ids(self) -> List[int]: + def _determine_device_ids(self) -> list[int]: return [self.root_device.index] @override @@ -249,7 +250,7 @@ def teardown(self) -> None: self.accelerator.teardown() @override - def lightning_module_state_dict(self) -> Dict[str, Any]: + def lightning_module_state_dict(self) -> dict[str, Any]: """Collects the state dict of the model. Only returns a non-empty state dict on rank 0 if ``save_distributed_checkpoint=False``. @@ -267,7 +268,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr pass @override - def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Any]: + def optimizer_state(self, optimizer: Optimizer) -> dict[str, Any]: """Collects the state of the given optimizer. Only returns a non-empty state dict on rank 0 if ``save_distributed_checkpoint=False``. @@ -296,7 +297,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: @override def save_checkpoint( - self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: if storage_options is not None: raise TypeError( @@ -328,7 +329,7 @@ def save_checkpoint( return super().save_checkpoint(checkpoint=checkpoint, filepath=path) @override - def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) state = { diff --git a/src/lightning/pytorch/strategies/parallel.py b/src/lightning/pytorch/strategies/parallel.py index 5658438cd3f53..285d40706a5a9 100644 --- a/src/lightning/pytorch/strategies/parallel.py +++ b/src/lightning/pytorch/strategies/parallel.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod +from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional +from typing import Any, Optional import torch from torch import Tensor @@ -33,7 +34,7 @@ class ParallelStrategy(Strategy, ABC): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[Precision] = None, @@ -71,15 +72,15 @@ def is_global_zero(self) -> bool: return self.global_rank == 0 @property - def parallel_devices(self) -> Optional[List[torch.device]]: + def parallel_devices(self) -> Optional[list[torch.device]]: return self._parallel_devices @parallel_devices.setter - def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> None: + def parallel_devices(self, parallel_devices: Optional[list[torch.device]]) -> None: self._parallel_devices = parallel_devices @property - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def distributed_sampler_kwargs(self) -> dict[str, Any]: return { "num_replicas": len(self.parallel_devices) if self.parallel_devices is not None else 0, "rank": self.global_rank, diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 314007f497f59..0a0f52e906dd5 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -13,8 +13,9 @@ # limitations under the License. import logging from abc import ABC, abstractmethod +from collections.abc import Generator, Mapping from contextlib import contextmanager -from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Optional, TypeVar, Union import torch from torch import Tensor @@ -61,9 +62,9 @@ def __init__( self._model: Optional[Module] = None self._launcher: Optional[_Launcher] = None self._forward_redirection: _ForwardRedirection = _ForwardRedirection() - self._optimizers: List[Optimizer] = [] - self._lightning_optimizers: List[LightningOptimizer] = [] - self.lr_scheduler_configs: List[LRSchedulerConfig] = [] + self._optimizers: list[Optimizer] = [] + self._lightning_optimizers: list[LightningOptimizer] = [] + self.lr_scheduler_configs: list[LRSchedulerConfig] = [] @property def launcher(self) -> Optional[_Launcher]: @@ -99,11 +100,11 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: self._precision_plugin = precision_plugin @property - def optimizers(self) -> List[Optimizer]: + def optimizers(self) -> list[Optimizer]: return self._optimizers @optimizers.setter - def optimizers(self, optimizers: List[Optimizer]) -> None: + def optimizers(self, optimizers: list[Optimizer]) -> None: self._optimizers = optimizers self._lightning_optimizers = [LightningOptimizer._to_lightning_optimizer(opt, self) for opt in optimizers] @@ -170,7 +171,7 @@ def setup_precision_plugin(self) -> None: self.optimizers = optimizers self.lr_scheduler_configs = lr_scheduler_configs - def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: + def optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: """Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom strategies. @@ -237,7 +238,7 @@ def optimizer_step( assert isinstance(model, pl.LightningModule) return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs) - def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: + def _setup_model_and_optimizers(self, model: Module, optimizers: list[Optimizer]) -> tuple[Module, list[Optimizer]]: """Setup a model and multiple optimizers together. The returned objects are expected to be in the same order they were passed in. The default implementation will @@ -362,7 +363,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: """Returns the pure LightningModule without potential wrappers.""" return self._lightning_module - def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: torch.cuda.empty_cache() return self.checkpoint_io.load_checkpoint(checkpoint_path) @@ -470,13 +471,13 @@ def handles_gradient_accumulation(self) -> bool: """Whether the strategy handles gradient accumulation internally.""" return False - def lightning_module_state_dict(self) -> Dict[str, Any]: + def lightning_module_state_dict(self) -> dict[str, Any]: """Returns model state.""" assert self.lightning_module is not None return self.lightning_module.state_dict() def save_checkpoint( - self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. @@ -587,13 +588,13 @@ def _reset_optimizers_and_schedulers(self) -> None: self._lightning_optimizers = [] self.lr_scheduler_configs = [] - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled state = dict(vars(self)) # copy state["_lightning_optimizers"] = [] return state - def __setstate__(self, state: Dict) -> None: + def __setstate__(self, state: dict) -> None: self.__dict__ = state self.optimizers = self.optimizers # re-create the `_lightning_optimizers` diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py index 56aae90c56897..faffb30d6256f 100644 --- a/src/lightning/pytorch/strategies/xla.py +++ b/src/lightning/pytorch/strategies/xla.py @@ -13,7 +13,7 @@ # limitations under the License. import io import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch from torch import Tensor @@ -49,7 +49,7 @@ class XLAStrategy(DDPStrategy): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, checkpoint_io: Optional[Union[XLACheckpointIO, _WrappingCheckpointIO]] = None, precision_plugin: Optional[XLAPrecision] = None, debug: bool = False, @@ -172,7 +172,7 @@ def _setup_model(self, model: Module) -> Module: # type: ignore @property @override - def distributed_sampler_kwargs(self) -> Dict[str, int]: + def distributed_sampler_kwargs(self) -> dict[str, int]: return {"num_replicas": self.world_size, "rank": self.global_rank} @override @@ -295,7 +295,7 @@ def set_world_ranks(self) -> None: @override def save_checkpoint( - self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: import torch_xla.core.xla_model as xm diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index 4c3bc5ef41bdd..012d1a2152aa3 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -14,7 +14,7 @@ import logging import signal from copy import deepcopy -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Callable, Optional, Union from packaging.version import Version @@ -115,7 +115,11 @@ def _call_configure_model(trainer: "pl.Trainer") -> None: # we don't normally check for this before calling the hook. it is done here to avoid instantiating the context # managers if is_overridden("configure_model", trainer.lightning_module): - with trainer.strategy.tensor_init_context(), trainer.strategy.model_sharded_context(), trainer.precision_plugin.module_init_context(): # noqa: E501 + with ( + trainer.strategy.tensor_init_context(), + trainer.strategy.model_sharded_context(), + trainer.precision_plugin.module_init_context(), + ): _call_lightning_module_hook(trainer, "configure_model") @@ -222,7 +226,7 @@ def _call_callback_hooks( pl_module._current_fx_name = prev_fx_name -def _call_callbacks_state_dict(trainer: "pl.Trainer") -> Dict[str, dict]: +def _call_callbacks_state_dict(trainer: "pl.Trainer") -> dict[str, dict]: """Called when saving a model checkpoint, calls and returns every callback's `state_dict`, keyed by `Callback.state_key`.""" callback_state_dicts = {} @@ -233,7 +237,7 @@ def _call_callbacks_state_dict(trainer: "pl.Trainer") -> Dict[str, dict]: return callback_state_dicts -def _call_callbacks_on_save_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[str, Any]) -> None: +def _call_callbacks_on_save_checkpoint(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None: """Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook.""" pl_module = trainer.lightning_module if pl_module: @@ -249,7 +253,7 @@ def _call_callbacks_on_save_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[s pl_module._current_fx_name = prev_fx_name -def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[str, Any]) -> None: +def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None: """Called when loading a model checkpoint. Calls every callback's `on_load_checkpoint` hook. We have a dedicated function for this rather than using @@ -261,7 +265,7 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[s prev_fx_name = pl_module._current_fx_name pl_module._current_fx_name = "on_load_checkpoint" - callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks") + callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks") if callback_states is None: return @@ -285,9 +289,9 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[s pl_module._current_fx_name = prev_fx_name -def _call_callbacks_load_state_dict(trainer: "pl.Trainer", checkpoint: Dict[str, Any]) -> None: +def _call_callbacks_load_state_dict(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None: """Called when loading a model checkpoint, calls every callback's `load_state_dict`.""" - callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks") + callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks") if callback_states is None: return diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 06f3ee366bcaa..40ee0eef4de33 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -15,7 +15,8 @@ import logging import os from collections import Counter -from typing import Dict, List, Literal, Optional, Union +from collections.abc import Iterable +from typing import Literal, Optional, Union import torch @@ -74,11 +75,11 @@ class _AcceleratorConnector: def __init__( self, - devices: Union[List[int], str, int] = "auto", + devices: Union[list[int], str, int] = "auto", num_nodes: int = 1, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]] = None, precision: Optional[_PRECISION_INPUT] = None, sync_batchnorm: bool = False, benchmark: Optional[bool] = None, @@ -123,7 +124,7 @@ def __init__( self._precision_flag: _PRECISION_INPUT_STR = "32-true" self._precision_plugin_flag: Optional[Precision] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None - self._parallel_devices: List[Union[int, torch.device, str]] = [] + self._parallel_devices: list[Union[int, torch.device, str]] = [] self._layer_sync: Optional[LayerSync] = TorchSyncBatchNorm() if sync_batchnorm else None self.checkpoint_io: Optional[CheckpointIO] = None @@ -166,7 +167,7 @@ def _check_config_and_set_final_flags( strategy: Union[str, Strategy], accelerator: Union[str, Accelerator], precision: Optional[_PRECISION_INPUT], - plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]], + plugins: Optional[Union[_PLUGIN_INPUT, Iterable[_PLUGIN_INPUT]]], sync_batchnorm: bool, ) -> None: """This method checks: @@ -182,7 +183,7 @@ def _check_config_and_set_final_flags( """ if plugins is not None: - plugins = [plugins] if not isinstance(plugins, list) else plugins + plugins = [plugins] if not isinstance(plugins, Iterable) else plugins if isinstance(strategy, str): strategy = strategy.lower() @@ -225,7 +226,7 @@ def _check_config_and_set_final_flags( precision_flag = _convert_precision_to_unified_args(precision) if plugins: - plugins_flags_types: Dict[str, int] = Counter() + plugins_flags_types: dict[str, int] = Counter() for plugin in plugins: if isinstance(plugin, Precision): self._precision_plugin_flag = plugin @@ -310,7 +311,7 @@ def _check_config_and_set_final_flags( self._accelerator_flag = "cuda" self._parallel_devices = self._strategy_flag.parallel_devices - def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: + def _check_device_config_and_set_final_flags(self, devices: Union[list[int], str, int], num_nodes: int) -> None: if not isinstance(num_nodes, int) or num_nodes < 1: raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.") diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index 2f2b619290ae5..a60f907d9361b 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -14,8 +14,9 @@ import logging import os +from collections.abc import Sequence from datetime import timedelta -from typing import Dict, List, Optional, Sequence, Union +from typing import Optional, Union import lightning.pytorch as pl from lightning.fabric.utilities.registry import _load_external_callbacks @@ -46,12 +47,12 @@ def __init__(self, trainer: "pl.Trainer"): def on_trainer_init( self, - callbacks: Optional[Union[List[Callback], Callback]], + callbacks: Optional[Union[list[Callback], Callback]], enable_checkpointing: bool, enable_progress_bar: bool, default_root_dir: Optional[str], enable_model_summary: bool, - max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, + max_time: Optional[Union[str, timedelta, dict[str, int]]] = None, ) -> None: # init folder paths for checkpoint + weights save callbacks self.trainer._default_root_dir = default_root_dir or os.getcwd() @@ -139,7 +140,7 @@ def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None: progress_bar_callback = TQDMProgressBar() self.trainer.callbacks.append(progress_bar_callback) - def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None: + def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, dict[str, int]]] = None) -> None: if max_time is None: return if any(isinstance(cb, Timer) for cb in self.trainer.callbacks): @@ -195,7 +196,7 @@ def _attach_model_callbacks(self) -> None: trainer.callbacks = all_callbacks @staticmethod - def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: + def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]: """Moves all the tuner specific callbacks at the beginning of the list and all the `ModelCheckpoint` callbacks to the end of the list. The sequential order within the group of checkpoint callbacks is preserved, as well as the order of all other callbacks. @@ -208,9 +209,9 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: if there were any present in the input. """ - tuner_callbacks: List[Callback] = [] - other_callbacks: List[Callback] = [] - checkpoint_callbacks: List[Callback] = [] + tuner_callbacks: list[Callback] = [] + other_callbacks: list[Callback] = [] + checkpoint_callbacks: list[Callback] = [] for cb in callbacks: if isinstance(cb, (BatchSizeFinder, LearningRateFinder)): @@ -223,7 +224,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: return tuner_callbacks + other_callbacks + checkpoint_callbacks -def _validate_callbacks_list(callbacks: List[Callback]) -> None: +def _validate_callbacks_list(callbacks: list[Callback]) -> None: stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)] seen_callbacks = set() for callback in stateful_callbacks: diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index c73ceb32ec77f..71cc5a14686be 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -14,7 +14,7 @@ import logging import os import re -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from fsspec.core import url_to_fs @@ -44,7 +44,7 @@ def __init__(self, trainer: "pl.Trainer") -> None: self._ckpt_path: Optional[_PATH] = None # flag to know if the user is changing the checkpoint path statefully. See `trainer.ckpt_path.setter` self._user_managed: bool = False - self._loaded_checkpoint: Dict[str, Any] = {} + self._loaded_checkpoint: dict[str, Any] = {} @property def _hpc_resume_path(self) -> Optional[str]: @@ -397,9 +397,7 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None self.resume_start(checkpoint_path) self.restore_model() self.restore_datamodule() - if self.trainer.state.fn == TrainerFn.FITTING: - # restore callback states - self.restore_callbacks() + self.restore_callbacks() def dump_checkpoint(self, weights_only: bool = False) -> dict: """Creating a model checkpoint dictionary object from various component states. @@ -493,10 +491,10 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: call._call_lightning_module_hook(trainer, "on_save_checkpoint", checkpoint) return checkpoint - def _get_lightning_module_state_dict(self) -> Dict[str, Tensor]: + def _get_lightning_module_state_dict(self) -> dict[str, Tensor]: return self.trainer.strategy.lightning_module_state_dict() - def _get_loops_state_dict(self) -> Dict[str, Any]: + def _get_loops_state_dict(self) -> dict[str, Any]: return { "fit_loop": self.trainer.fit_loop.state_dict(), "validate_loop": self.trainer.validate_loop.state_dict(), diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 1e84a2ebd0244..3e5273085ed2b 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Any, Iterable, Optional, Tuple, Union +from typing import Any, Optional, Union import torch.multiprocessing as mp from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler @@ -342,7 +343,7 @@ class _DataHookSelector: model: "pl.LightningModule" datamodule: Optional["pl.LightningDataModule"] - _valid_hooks: Tuple[str, ...] = field( + _valid_hooks: tuple[str, ...] = field( default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") ) diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py index 545749bc5b321..0dbdc4eaf76e1 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Union from typing_extensions import TypedDict @@ -20,8 +20,8 @@ class _FxValidator: class _LogOptions(TypedDict): - allowed_on_step: Union[Tuple[bool], Tuple[bool, bool]] - allowed_on_epoch: Union[Tuple[bool], Tuple[bool, bool]] + allowed_on_step: Union[tuple[bool], tuple[bool, bool]] + allowed_on_epoch: Union[tuple[bool], tuple[bool, bool]] default_on_step: bool default_on_epoch: bool @@ -166,7 +166,7 @@ def check_logging(cls, fx_name: str) -> None: @classmethod def get_default_logging_levels( cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool] - ) -> Tuple[bool, bool]: + ) -> tuple[bool, bool]: """Return default logging levels for given hook.""" fx_config = cls.functions[fx_name] assert fx_config is not None @@ -191,7 +191,7 @@ def check_logging_levels(cls, fx_name: str, on_step: bool, on_epoch: bool) -> No @classmethod def check_logging_and_get_default_levels( cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool] - ) -> Tuple[bool, bool]: + ) -> tuple[bool, bool]: """Check if the given hook name is allowed to log and return logging levels.""" cls.check_logging(fx_name) on_step, on_epoch = cls.get_default_logging_levels(fx_name, on_step, on_epoch) diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py index c4ab11632b56b..ffc99a9772469 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterable, Optional, Union +from collections.abc import Iterable +from typing import Any, Optional, Union from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index 583105c3660e0..fdde19aa80eea 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Generator from dataclasses import dataclass from functools import partial, wraps -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast +from typing import Any, Callable, Optional, Union, cast import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -32,8 +33,8 @@ from lightning.pytorch.utilities.warnings import PossibleUserWarning _VALUE = Union[Metric, Tensor] # Do not include scalars as they were converted to tensors -_OUT_DICT = Dict[str, Tensor] -_PBAR_DICT = Dict[str, float] +_OUT_DICT = dict[str, Tensor] +_PBAR_DICT = dict[str, float] class _METRICS(TypedDict): @@ -333,7 +334,7 @@ def __init__(self, training: bool) -> None: self.dataloader_idx: Optional[int] = None @property - def result_metrics(self) -> List[_ResultMetric]: + def result_metrics(self) -> list[_ResultMetric]: return list(self.values()) def _extract_batch_size(self, value: _ResultMetric, batch_size: Optional[int], meta: _Metadata) -> int: @@ -351,6 +352,7 @@ def _extract_batch_size(self, value: _ResultMetric, batch_size: Optional[int], m return batch_size + @torch.compiler.disable def log( self, fx: str, @@ -413,6 +415,7 @@ def log( batch_size = self._extract_batch_size(self[key], batch_size, meta) self.update_metrics(key, value, batch_size) + @torch.compiler.disable def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None: result_metric = self[key] # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl` @@ -454,7 +457,7 @@ def valid_items(self) -> Generator: """This function is used to iterate over current valid metrics.""" return ((k, v) for k, v in self.items() if not v.has_reset and self.dataloader_idx == v.meta.dataloader_idx) - def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> Tuple[str, str]: + def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> tuple[str, str]: name = result_metric.meta.name forked_name = result_metric.meta.forked_name(on_step) add_dataloader_idx = result_metric.meta.add_dataloader_idx diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py index 05a975326005f..e63fecd3897f2 100644 --- a/src/lightning/pytorch/trainer/connectors/signal_connector.py +++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py @@ -5,7 +5,7 @@ import threading from subprocess import call from types import FrameType -from typing import Any, Callable, Dict, List, Set, Union +from typing import Any, Callable, Union import lightning.pytorch as pl from lightning.fabric.plugins.environments import SLURMEnvironment @@ -20,7 +20,7 @@ class _HandlersCompose: - def __init__(self, signal_handlers: Union[List[_HANDLER], _HANDLER]) -> None: + def __init__(self, signal_handlers: Union[list[_HANDLER], _HANDLER]) -> None: if not isinstance(signal_handlers, list): signal_handlers = [signal_handlers] self.signal_handlers = signal_handlers @@ -37,14 +37,14 @@ class _SignalConnector: def __init__(self, trainer: "pl.Trainer") -> None: self.received_sigterm = False self.trainer = trainer - self._original_handlers: Dict[_SIGNUM, _HANDLER] = {} + self._original_handlers: dict[_SIGNUM, _HANDLER] = {} def register_signal_handlers(self) -> None: self.received_sigterm = False self._original_handlers = self._get_current_signal_handlers() - sigusr_handlers: List[_HANDLER] = [] - sigterm_handlers: List[_HANDLER] = [self._sigterm_notifier_fn] + sigusr_handlers: list[_HANDLER] = [] + sigterm_handlers: list[_HANDLER] = [self._sigterm_notifier_fn] environment = self.trainer._accelerator_connector.cluster_environment if isinstance(environment, SLURMEnvironment) and environment.auto_requeue: @@ -123,7 +123,7 @@ def teardown(self) -> None: self._original_handlers = {} @staticmethod - def _get_current_signal_handlers() -> Dict[_SIGNUM, _HANDLER]: + def _get_current_signal_handlers() -> dict[_SIGNUM, _HANDLER]: """Collects the currently assigned signal handlers.""" valid_signals = _SignalConnector._valid_signals() if not _IS_WINDOWS: @@ -132,7 +132,7 @@ def _get_current_signal_handlers() -> Dict[_SIGNUM, _HANDLER]: return {signum: signal.getsignal(signum) for signum in valid_signals} @staticmethod - def _valid_signals() -> Set[signal.Signals]: + def _valid_signals() -> set[signal.Signals]: """Returns all valid signals supported on the current platform.""" return signal.valid_signals() @@ -145,7 +145,7 @@ def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None: if threading.current_thread() is threading.main_thread(): signal.signal(signum, handlers) # type: ignore[arg-type] - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: state = self.__dict__.copy() state["_original_handlers"] = {} return state diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 406f686efe732..0509f28acb07a 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -23,9 +23,10 @@ import logging import math import os +from collections.abc import Generator, Iterable from contextlib import contextmanager from datetime import timedelta -from typing import Any, Dict, Generator, Iterable, List, Optional, Union +from typing import Any, Optional, Union from weakref import proxy import torch @@ -90,17 +91,17 @@ def __init__( *, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - devices: Union[List[int], str, int] = "auto", + devices: Union[list[int], str, int] = "auto", num_nodes: int = 1, precision: Optional[_PRECISION_INPUT] = None, logger: Optional[Union[Logger, Iterable[Logger], bool]] = None, - callbacks: Optional[Union[List[Callback], Callback]] = None, + callbacks: Optional[Union[list[Callback], Callback]] = None, fast_dev_run: Union[int, bool] = False, max_epochs: Optional[int] = None, min_epochs: Optional[int] = None, max_steps: int = -1, min_steps: Optional[int] = None, - max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, + max_time: Optional[Union[str, timedelta, dict[str, int]]] = None, limit_train_batches: Optional[Union[int, float]] = None, limit_val_batches: Optional[Union[int, float]] = None, limit_test_batches: Optional[Union[int, float]] = None, @@ -123,7 +124,7 @@ def __init__( profiler: Optional[Union[Profiler, str]] = None, detect_anomaly: bool = False, barebones: bool = False, - plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, default_root_dir: Optional[_PATH] = None, @@ -472,7 +473,7 @@ def __init__( setup._init_profiler(self, profiler) # init logger flags - self._loggers: List[Logger] + self._loggers: list[Logger] self._logger_connector.on_trainer_init(logger, log_every_n_steps) # init debugging flags @@ -1149,7 +1150,7 @@ def num_nodes(self) -> int: return getattr(self.strategy, "num_nodes", 1) @property - def device_ids(self) -> List[int]: + def device_ids(self) -> list[int]: """List of device indexes per node.""" devices = ( self.strategy.parallel_devices @@ -1176,15 +1177,15 @@ def lightning_module(self) -> "pl.LightningModule": return self.strategy.lightning_module # type: ignore[return-value] @property - def optimizers(self) -> List[Optimizer]: + def optimizers(self) -> list[Optimizer]: return self.strategy.optimizers @optimizers.setter - def optimizers(self, new_optims: List[Optimizer]) -> None: + def optimizers(self, new_optims: list[Optimizer]) -> None: self.strategy.optimizers = new_optims @property - def lr_scheduler_configs(self) -> List[LRSchedulerConfig]: + def lr_scheduler_configs(self) -> list[LRSchedulerConfig]: return self.strategy.lr_scheduler_configs @property @@ -1247,7 +1248,7 @@ def training_step(self, batch, batch_idx): return self.strategy.is_global_zero @property - def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]: + def distributed_sampler_kwargs(self) -> Optional[dict[str, Any]]: if isinstance(self.strategy, ParallelStrategy): return self.strategy.distributed_sampler_kwargs return None @@ -1280,7 +1281,7 @@ def early_stopping_callback(self) -> Optional[EarlyStopping]: return callbacks[0] if len(callbacks) > 0 else None @property - def early_stopping_callbacks(self) -> List[EarlyStopping]: + def early_stopping_callbacks(self) -> list[EarlyStopping]: """A list of all instances of :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` found in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, EarlyStopping)] @@ -1293,7 +1294,7 @@ def checkpoint_callback(self) -> Optional[Checkpoint]: return callbacks[0] if len(callbacks) > 0 else None @property - def checkpoint_callbacks(self) -> List[Checkpoint]: + def checkpoint_callbacks(self) -> list[Checkpoint]: """A list of all instances of :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` found in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, Checkpoint)] @@ -1361,9 +1362,10 @@ def save_checkpoint( "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call" " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?" ) - checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only) - self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options) - self.strategy.barrier("Trainer.save_checkpoint") + with self.profiler.profile("save_checkpoint"): + checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only) + self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options) + self.strategy.barrier("Trainer.save_checkpoint") """ State properties @@ -1521,14 +1523,14 @@ def num_training_batches(self) -> Union[int, float]: return self.fit_loop.max_batches @property - def num_sanity_val_batches(self) -> List[Union[int, float]]: + def num_sanity_val_batches(self) -> list[Union[int, float]]: """The number of validation batches that will be used during the sanity-checking part of ``trainer.fit()``.""" max_batches = self.fit_loop.epoch_loop.val_loop.max_batches # re-compute the `min` in case this is called outside the sanity-checking stage return [min(self.num_sanity_val_steps, batches) for batches in max_batches] @property - def num_val_batches(self) -> List[Union[int, float]]: + def num_val_batches(self) -> list[Union[int, float]]: """The number of validation batches that will be used during ``trainer.fit()`` or ``trainer.validate()``.""" if self.state.fn == TrainerFn.VALIDATING: return self.validate_loop.max_batches @@ -1537,12 +1539,12 @@ def num_val_batches(self) -> List[Union[int, float]]: return self.fit_loop.epoch_loop.val_loop._max_batches @property - def num_test_batches(self) -> List[Union[int, float]]: + def num_test_batches(self) -> list[Union[int, float]]: """The number of test batches that will be used during ``trainer.test()``.""" return self.test_loop.max_batches @property - def num_predict_batches(self) -> List[Union[int, float]]: + def num_predict_batches(self) -> list[Union[int, float]]: """The number of prediction batches that will be used during ``trainer.predict()``.""" return self.predict_loop.max_batches @@ -1583,7 +1585,7 @@ def logger(self, logger: Optional[Logger]) -> None: self.loggers = [logger] @property - def loggers(self) -> List[Logger]: + def loggers(self) -> list[Logger]: """The list of :class:`~lightning.pytorch.loggers.logger.Logger` used. .. code-block:: python @@ -1595,7 +1597,7 @@ def loggers(self) -> List[Logger]: return self._loggers @loggers.setter - def loggers(self, loggers: Optional[List[Logger]]) -> None: + def loggers(self, loggers: Optional[list[Logger]]) -> None: self._loggers = loggers if loggers else [] @property diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py index 6618f7e930ca1..99badd84bb8ad 100644 --- a/src/lightning/pytorch/tuner/batch_size_scaling.py +++ b/src/lightning/pytorch/tuner/batch_size_scaling.py @@ -15,7 +15,7 @@ import os import uuid from copy import deepcopy -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional import lightning.pytorch as pl from lightning.pytorch.utilities.memory import garbage_collection_cuda, is_oom_error @@ -98,7 +98,7 @@ def _scale_batch_size( return new_size -def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: +def __scale_batch_dump_params(trainer: "pl.Trainer") -> dict[str, Any]: dumped_params = { "loggers": trainer.loggers, "callbacks": trainer.callbacks, @@ -138,7 +138,7 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> N loop.verbose = False -def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: +def __scale_batch_restore_params(trainer: "pl.Trainer", params: dict[str, Any]) -> None: # TODO: There are more states that needs to be reset (#4512 and #4870) trainer.loggers = params["loggers"] trainer.callbacks = params["callbacks"] @@ -169,7 +169,7 @@ def _run_power_scaling( new_size: int, batch_arg_name: str, max_trials: int, - params: Dict[str, Any], + params: dict[str, Any], ) -> int: """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered.""" # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not @@ -211,7 +211,7 @@ def _run_binary_scaling( new_size: int, batch_arg_name: str, max_trials: int, - params: Dict[str, Any], + params: dict[str, Any], ) -> int: """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. @@ -276,7 +276,7 @@ def _adjust_batch_size( factor: float = 1.0, value: Optional[int] = None, desc: Optional[str] = None, -) -> Tuple[int, bool]: +) -> tuple[int, bool]: """Helper function for adjusting the batch size. Args: @@ -328,7 +328,7 @@ def _reset_dataloaders(trainer: "pl.Trainer") -> None: loop.epoch_loop.val_loop.setup_data() -def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: +def _try_loop_run(trainer: "pl.Trainer", params: dict[str, Any]) -> None: loop = trainer._active_loop assert loop is not None loop.load_state_dict(deepcopy(params["loop_state_dict"])) diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py index d756d3d76597c..b50bedb10d53f 100644 --- a/src/lightning/pytorch/tuner/lr_finder.py +++ b/src/lightning/pytorch/tuner/lr_finder.py @@ -16,7 +16,7 @@ import os import uuid from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -101,7 +101,7 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) - self.lr_max = lr_max self.num_training = num_training - self.results: Dict[str, Any] = {} + self.results: dict[str, Any] = {} self._total_batch_idx = 0 # for debug purpose def _exchange_scheduler(self, trainer: "pl.Trainer") -> None: @@ -310,7 +310,7 @@ def _lr_find( return lr_finder -def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: +def __lr_finder_dump_params(trainer: "pl.Trainer") -> dict[str, Any]: return { "optimizers": trainer.strategy.optimizers, "lr_scheduler_configs": trainer.strategy.lr_scheduler_configs, @@ -335,7 +335,7 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto trainer.limit_val_batches = num_training -def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: +def __lr_finder_restore_params(trainer: "pl.Trainer", params: dict[str, Any]) -> None: trainer.strategy.optimizers = params["optimizers"] trainer.strategy.lr_scheduler_configs = params["lr_scheduler_configs"] trainer.callbacks = params["callbacks"] @@ -376,8 +376,8 @@ def __init__( self.num_training = num_training self.early_stop_threshold = early_stop_threshold self.beta = beta - self.losses: List[float] = [] - self.lrs: List[float] = [] + self.losses: list[float] = [] + self.lrs: list[float] = [] self.avg_loss = 0.0 self.best_loss = 0.0 self.progress_bar_refresh_rate = progress_bar_refresh_rate @@ -463,7 +463,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> List[float]: + def get_lr(self) -> list[float]: curr_iter = self.last_epoch + 1 r = curr_iter / self.num_iter @@ -475,7 +475,7 @@ def get_lr(self) -> List[float]: return val @property - def lr(self) -> Union[float, List[float]]: + def lr(self) -> Union[float, list[float]]: return self._lr @@ -500,7 +500,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> List[float]: + def get_lr(self) -> list[float]: curr_iter = self.last_epoch + 1 r = curr_iter / self.num_iter @@ -512,11 +512,11 @@ def get_lr(self) -> List[float]: return val @property - def lr(self) -> Union[float, List[float]]: + def lr(self) -> Union[float, list[float]]: return self._lr -def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: +def _try_loop_run(trainer: "pl.Trainer", params: dict[str, Any]) -> None: loop = trainer.fit_loop loop.load_state_dict(deepcopy(params["loop_state_dict"])) loop.restarting = False diff --git a/src/lightning/pytorch/utilities/_pytree.py b/src/lightning/pytorch/utilities/_pytree.py index f5f48b481c879..a0c7236cb27f1 100644 --- a/src/lightning/pytorch/utilities/_pytree.py +++ b/src/lightning/pytorch/utilities/_pytree.py @@ -1,4 +1,4 @@ -from typing import Any, List, Tuple +from typing import Any from torch.utils._pytree import SUPPORTED_NODES, LeafSpec, PyTree, TreeSpec, _get_node_type, tree_unflatten @@ -15,7 +15,7 @@ def _is_leaf_or_primitive_container(pytree: PyTree) -> bool: return all(isinstance(child, (int, float, str)) for child in child_pytrees) -def _tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: +def _tree_flatten(pytree: PyTree) -> tuple[list[Any], TreeSpec]: """Copy of :func:`torch.utils._pytree.tree_flatten` using our custom leaf function.""" if _is_leaf_or_primitive_container(pytree): return [pytree], LeafSpec() @@ -24,8 +24,8 @@ def _tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: flatten_fn = SUPPORTED_NODES[node_type].flatten_fn child_pytrees, context = flatten_fn(pytree) - result: List[Any] = [] - children_specs: List[TreeSpec] = [] + result: list[Any] = [] + children_specs: list[TreeSpec] = [] for child in child_pytrees: flat, child_spec = _tree_flatten(child) result += flat @@ -34,6 +34,6 @@ def _tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: return result, TreeSpec(node_type, context, children_specs) -def _map_and_unflatten(fn: Any, values: List[Any], spec: TreeSpec) -> PyTree: +def _map_and_unflatten(fn: Any, values: list[Any], spec: TreeSpec) -> PyTree: """Utility function to apply a function and unflatten it.""" return tree_unflatten([fn(i) for i in values], spec) diff --git a/src/lightning/pytorch/utilities/argparse.py b/src/lightning/pytorch/utilities/argparse.py index eb7273b54a577..1e01297248ffa 100644 --- a/src/lightning/pytorch/utilities/argparse.py +++ b/src/lightning/pytorch/utilities/argparse.py @@ -19,12 +19,12 @@ from ast import literal_eval from contextlib import suppress from functools import wraps -from typing import Any, Callable, Type, TypeVar, cast +from typing import Any, Callable, TypeVar, cast _T = TypeVar("_T", bound=Callable[..., Any]) -def _parse_env_variables(cls: Type, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: +def _parse_env_variables(cls: type, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: """Parse environment arguments if they are defined. Examples: diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index 9b0ceb0288e87..9c89c998aa913 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib -from collections.abc import Iterable -from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union +from collections.abc import Iterable, Iterator +from typing import Any, Callable, Literal, Optional, Union from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter from typing_extensions import Self, TypedDict, override @@ -22,15 +22,15 @@ from lightning.fabric.utilities.types import _Stateful from lightning.pytorch.utilities._pytree import _map_and_unflatten, _tree_flatten, tree_unflatten -_ITERATOR_RETURN = Tuple[Any, int, int] # batch, batch_idx, dataloader_idx +_ITERATOR_RETURN = tuple[Any, int, int] # batch, batch_idx, dataloader_idx class _ModeIterator(Iterator[_ITERATOR_RETURN]): - def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: + def __init__(self, iterables: list[Iterable], limits: Optional[list[Union[int, float]]] = None) -> None: if limits is not None and len(limits) != len(iterables): raise ValueError(f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(iterables)})") self.iterables = iterables - self.iterators: List[Iterator] = [] + self.iterators: list[Iterator] = [] self._idx = 0 # what would be batch_idx self.limits = limits @@ -51,7 +51,7 @@ def reset(self) -> None: self.iterators = [] self._idx = 0 - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() # workaround an inconvenient `NotImplementedError`: @@ -65,9 +65,9 @@ def __getstate__(self) -> Dict[str, Any]: class _MaxSizeCycle(_ModeIterator): - def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: + def __init__(self, iterables: list[Iterable], limits: Optional[list[Union[int, float]]] = None) -> None: super().__init__(iterables, limits) - self._consumed: List[bool] = [] + self._consumed: list[bool] = [] @override def __next__(self) -> _ITERATOR_RETURN: @@ -121,7 +121,7 @@ def __len__(self) -> int: class _Sequential(_ModeIterator): - def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: + def __init__(self, iterables: list[Iterable], limits: Optional[list[Union[int, float]]] = None) -> None: super().__init__(iterables, limits) self._iterator_idx = 0 # what would be dataloader_idx @@ -206,8 +206,8 @@ def __len__(self) -> int: class _CombinationMode(TypedDict): - fn: Callable[[List[int]], int] - iterator: Type[_ModeIterator] + fn: Callable[[list[int]], int] + iterator: type[_ModeIterator] _SUPPORTED_MODES = { @@ -288,7 +288,7 @@ def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") self._flattened, self._spec = _tree_flatten(iterables) self._mode = mode self._iterator: Optional[_ModeIterator] = None - self._limits: Optional[List[Union[int, float]]] = None + self._limits: Optional[list[Union[int, float]]] = None @property def iterables(self) -> Any: @@ -306,12 +306,12 @@ def batch_sampler(self) -> Any: return _map_and_unflatten(lambda x: getattr(x, "batch_sampler", None), self.flattened, self._spec) @property - def flattened(self) -> List[Any]: + def flattened(self) -> list[Any]: """Return the flat list of iterables.""" return self._flattened @flattened.setter - def flattened(self, flattened: List[Any]) -> None: + def flattened(self, flattened: list[Any]) -> None: """Setter to conveniently update the list of iterables.""" if len(flattened) != len(self._flattened): raise ValueError( @@ -322,12 +322,12 @@ def flattened(self, flattened: List[Any]) -> None: self._flattened = flattened @property - def limits(self) -> Optional[List[Union[int, float]]]: + def limits(self) -> Optional[list[Union[int, float]]]: """Optional limits per iterator.""" return self._limits @limits.setter - def limits(self, limits: Optional[Union[int, float, List[Union[int, float]]]]) -> None: + def limits(self, limits: Optional[Union[int, float, list[Union[int, float]]]]) -> None: if isinstance(limits, (int, float)): limits = [limits] * len(self.flattened) elif isinstance(limits, list) and len(limits) != len(self.flattened): @@ -375,11 +375,11 @@ def _dataset_length(self) -> int: fn = _SUPPORTED_MODES[self._mode]["fn"] return fn(lengths) - def _state_dicts(self) -> List[Dict[str, Any]]: + def _state_dicts(self) -> list[dict[str, Any]]: """Returns the list of state dicts for iterables in `self.flattened` that are stateful.""" return [loader.state_dict() for loader in self.flattened if isinstance(loader, _Stateful)] - def _load_state_dicts(self, states: List[Dict[str, Any]]) -> None: + def _load_state_dicts(self, states: list[dict[str, Any]]) -> None: """Loads the state dicts for iterables in `self.flattened` that are stateful.""" if not states: return @@ -401,5 +401,5 @@ def _shutdown_workers_and_reset_iterator(dataloader: object) -> None: dataloader._iterator = None -def _get_iterables_lengths(iterables: List[Iterable]) -> List[Union[int, float]]: +def _get_iterables_lengths(iterables: list[Iterable]) -> list[Union[int, float]]: return [(float("inf") if (length := sized_len(iterable)) is None else length) for iterable in iterables] diff --git a/src/lightning/pytorch/utilities/consolidate_checkpoint.py b/src/lightning/pytorch/utilities/consolidate_checkpoint.py index 6f150bab0f23c..0dcf5879b6fc5 100644 --- a/src/lightning/pytorch/utilities/consolidate_checkpoint.py +++ b/src/lightning/pytorch/utilities/consolidate_checkpoint.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict +from typing import Any import torch @@ -7,7 +7,7 @@ from lightning.fabric.utilities.load import _load_distributed_checkpoint -def _format_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]: +def _format_checkpoint(checkpoint: dict[str, Any]) -> dict[str, Any]: """Converts the special FSDP checkpoint format to the standard format the Lightning Trainer can load.""" # Rename the model key checkpoint["state_dict"] = checkpoint.pop("model") diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index 41c5ea86e50fb..5c14561f7aff9 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +from collections.abc import Generator, Iterable, Mapping, Sized from dataclasses import fields -from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union +from typing import Any, Optional, Union import torch from lightning_utilities.core.apply_func import is_dataclass_instance @@ -139,7 +140,7 @@ def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None, -) -> Tuple[Tuple[Any], Dict[str, Any]]: +) -> tuple[tuple[Any], dict[str, Any]]: if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") @@ -233,7 +234,7 @@ def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re- instantiation. @@ -349,7 +350,8 @@ def _is_dataloader_shuffled(dataloader: object) -> bool: if not hasattr(dataloader, "sampler"): # shuffling is enabled via a sampler. No sampler, no shuffling return False - sampler = dataloader.sampler + batch_sampler = dataloader.batch_sampler + sampler = batch_sampler.sampler if batch_sampler is not None else dataloader.sampler if isinstance(sampler, SequentialSampler): return False return isinstance(sampler, RandomSampler) diff --git a/src/lightning/pytorch/utilities/grads.py b/src/lightning/pytorch/utilities/grads.py index f200d892474db..08a0230f759cf 100644 --- a/src/lightning/pytorch/utilities/grads.py +++ b/src/lightning/pytorch/utilities/grads.py @@ -13,13 +13,13 @@ # limitations under the License. """Utilities to describe gradients.""" -from typing import Dict, Union +from typing import Union import torch from torch.nn import Module -def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator: str = "/") -> Dict[str, float]: +def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator: str = "/") -> dict[str, float]: """Compute each parameter's gradient's norm and their overall norm. The overall norm is computed over all gradients together, as if they diff --git a/src/lightning/pytorch/utilities/migration/migration.py b/src/lightning/pytorch/utilities/migration/migration.py index 6a5a914bed9ba..5db942b29183f 100644 --- a/src/lightning/pytorch/utilities/migration/migration.py +++ b/src/lightning/pytorch/utilities/migration/migration.py @@ -31,17 +31,17 @@ """ import re -from typing import Any, Callable, Dict, List +from typing import Any, Callable from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch.callbacks.early_stopping import EarlyStopping from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.utilities.rank_zero import rank_zero_warn -_CHECKPOINT = Dict[str, Any] +_CHECKPOINT = dict[str, Any] -def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: +def _migration_index() -> dict[str, list[Callable[[_CHECKPOINT], _CHECKPOINT]]]: """Migration functions returned here will get executed in the order they are listed.""" return { "0.10.0": [_migrate_model_checkpoint_early_stopping], @@ -133,7 +133,7 @@ def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT: return checkpoint -def _get_fit_loop_initial_state_1_6_0() -> Dict: +def _get_fit_loop_initial_state_1_6_0() -> dict: return { "epoch_loop.batch_loop.manual_loop.optim_step_progress": { "current": {"completed": 0, "ready": 0}, diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index 1537c2684fbe4..2c5656e1f1016 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -18,7 +18,7 @@ import threading import warnings from types import ModuleType, TracebackType -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Optional from packaging.version import Version from typing_extensions import override @@ -32,13 +32,13 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_warn _log = logging.getLogger(__name__) -_CHECKPOINT = Dict[str, Any] +_CHECKPOINT = dict[str, Any] _lock = threading.Lock() def migrate_checkpoint( checkpoint: _CHECKPOINT, target_version: Optional[str] = None -) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]: +) -> tuple[_CHECKPOINT, dict[str, list[str]]]: """Applies Lightning version migrations to a checkpoint dictionary. Args: @@ -121,7 +121,7 @@ class _FaultTolerantMode(LightningEnum): def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], exc_traceback: Optional[TracebackType], ) -> None: diff --git a/src/lightning/pytorch/utilities/model_helpers.py b/src/lightning/pytorch/utilities/model_helpers.py index 36adedf4a831b..44591aa7f4dc1 100644 --- a/src/lightning/pytorch/utilities/model_helpers.py +++ b/src/lightning/pytorch/utilities/model_helpers.py @@ -15,7 +15,7 @@ import inspect import logging import os -from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar from lightning_utilities.core.imports import RequirementCache from torch import nn @@ -26,7 +26,7 @@ _log = logging.getLogger(__name__) -def is_overridden(method_name: str, instance: Optional[object] = None, parent: Optional[Type[object]] = None) -> bool: +def is_overridden(method_name: str, instance: Optional[object] = None, parent: Optional[type[object]] = None) -> bool: if instance is None: # if `self.lightning_module` was passed as instance, it can be `None` return False @@ -65,7 +65,7 @@ class _ModuleMode: """Captures the ``nn.Module.training`` (bool) mode of every submodule, and allows it to be restored later on.""" def __init__(self) -> None: - self.mode: Dict[str, bool] = {} + self.mode: dict[str, bool] = {} def capture(self, module: nn.Module) -> None: self.mode.clear() @@ -108,10 +108,10 @@ class _restricted_classmethod_impl(Generic[_T, _P, _R_co]): """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance instead of a class type.""" - def __init__(self, method: Callable[Concatenate[Type[_T], _P], _R_co]) -> None: + def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None: self.method = method - def __get__(self, instance: Optional[_T], cls: Type[_T]) -> Callable[_P, _R_co]: + def __get__(self, instance: Optional[_T], cls: type[_T]) -> Callable[_P, _R_co]: # The wrapper ensures that the method can be inspected, but not called on an instance @functools.wraps(self.method) def wrapper(*args: Any, **kwargs: Any) -> _R_co: diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py index c40dc94568a51..6a5baf2c1e04a 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py @@ -17,7 +17,7 @@ import logging import math from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch import torch.nn as nn @@ -73,8 +73,8 @@ def __init__(self, module: nn.Module) -> None: super().__init__() self._module = module self._hook_handle = self._register_hook() - self._in_size: Optional[Union[str, List]] = None - self._out_size: Optional[Union[str, List]] = None + self._in_size: Optional[Union[str, list]] = None + self._out_size: Optional[Union[str, list]] = None def __del__(self) -> None: self.detach_hook() @@ -121,11 +121,11 @@ def detach_hook(self) -> None: self._hook_handle.remove() @property - def in_size(self) -> Union[str, List]: + def in_size(self) -> Union[str, list]: return self._in_size or UNKNOWN_SIZE @property - def out_size(self) -> Union[str, List]: + def out_size(self) -> Union[str, list]: return self._out_size or UNKNOWN_SIZE @property @@ -221,8 +221,8 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None: self._precision_megabytes = (precision / 8.0) * 1e-6 @property - def named_modules(self) -> List[Tuple[str, nn.Module]]: - mods: List[Tuple[str, nn.Module]] + def named_modules(self) -> list[tuple[str, nn.Module]]: + mods: list[tuple[str, nn.Module]] if self._max_depth == 0: mods = [] elif self._max_depth == 1: @@ -234,31 +234,31 @@ def named_modules(self) -> List[Tuple[str, nn.Module]]: return mods @property - def layer_names(self) -> List[str]: + def layer_names(self) -> list[str]: return list(self._layer_summary.keys()) @property - def layer_types(self) -> List[str]: + def layer_types(self) -> list[str]: return [layer.layer_type for layer in self._layer_summary.values()] @property - def in_sizes(self) -> List: + def in_sizes(self) -> list: return [layer.in_size for layer in self._layer_summary.values()] @property - def out_sizes(self) -> List: + def out_sizes(self) -> list: return [layer.out_size for layer in self._layer_summary.values()] @property - def param_nums(self) -> List[int]: + def param_nums(self) -> list[int]: return [layer.num_parameters for layer in self._layer_summary.values()] @property - def training_modes(self) -> List[bool]: + def training_modes(self) -> list[bool]: return [layer.training for layer in self._layer_summary.values()] @property - def total_training_modes(self) -> Dict[str, int]: + def total_training_modes(self) -> dict[str, int]: modes = [layer.training for layer in self._model.modules()] modes = modes[1:] # exclude the root module return {"train": modes.count(True), "eval": modes.count(False)} @@ -279,7 +279,7 @@ def total_layer_params(self) -> int: def model_size(self) -> float: return self.total_parameters * self._precision_megabytes - def summarize(self) -> Dict[str, LayerSummary]: + def summarize(self) -> dict[str, LayerSummary]: summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules) if self._model.example_input_array is not None: self._forward_example_input() @@ -318,7 +318,7 @@ def _forward_example_input(self) -> None: model(input_) mode.restore(model) - def _get_summary_data(self) -> List[Tuple[str, List[str]]]: + def _get_summary_data(self) -> list[tuple[str, list[str]]]: """Makes a summary listing with: Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size @@ -341,7 +341,7 @@ def _get_summary_data(self) -> List[Tuple[str, List[str]]]: return arrays - def _add_leftover_params_to_summary(self, arrays: List[Tuple[str, List[str]]], total_leftover_params: int) -> None: + def _add_leftover_params_to_summary(self, arrays: list[tuple[str, list[str]]], total_leftover_params: int) -> None: """Add summary of params not associated with module or layer to model summary.""" layer_summaries = dict(arrays) layer_summaries[" "].append(" ") @@ -368,7 +368,7 @@ def __repr__(self) -> str: return str(self) -def parse_batch_shape(batch: Any) -> Union[str, List]: +def parse_batch_shape(batch: Any) -> Union[str, list]: if hasattr(batch, "shape"): return list(batch.shape) @@ -382,8 +382,8 @@ def _format_summary_table( total_parameters: int, trainable_parameters: int, model_size: float, - total_training_modes: Dict[str, int], - *cols: Tuple[str, List[str]], + total_training_modes: dict[str, int], + *cols: tuple[str, list[str]], ) -> str: """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big string defining the summary table that are nicely formatted.""" diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py index 57d9ae5024b58..5038aebf0db79 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py @@ -14,7 +14,6 @@ """Utilities that can be used with Deepspeed.""" from collections import OrderedDict -from typing import Dict, List, Tuple import torch from lightning_utilities.core.imports import RequirementCache @@ -54,7 +53,7 @@ def partitioned_size(p: Parameter) -> int: class DeepSpeedSummary(ModelSummary): @override - def summarize(self) -> Dict[str, DeepSpeedLayerSummary]: # type: ignore[override] + def summarize(self) -> dict[str, DeepSpeedLayerSummary]: # type: ignore[override] summary = OrderedDict((name, DeepSpeedLayerSummary(module)) for name, module in self.named_modules) if self._model.example_input_array is not None: self._forward_example_input() @@ -83,11 +82,11 @@ def trainable_parameters(self) -> int: ) @property - def parameters_per_layer(self) -> List[int]: + def parameters_per_layer(self) -> list[int]: return [layer.average_shard_parameters for layer in self._layer_summary.values()] @override - def _get_summary_data(self) -> List[Tuple[str, List[str]]]: + def _get_summary_data(self) -> list[tuple[str, list[str]]]: """Makes a summary listing with: Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size @@ -112,7 +111,7 @@ def _get_summary_data(self) -> List[Tuple[str, List[str]]]: return arrays @override - def _add_leftover_params_to_summary(self, arrays: List[Tuple[str, List[str]]], total_leftover_params: int) -> None: + def _add_leftover_params_to_summary(self, arrays: list[tuple[str, list[str]]], total_leftover_params: int) -> None: """Add summary of params not associated with module or layer to model summary.""" super()._add_leftover_params_to_summary(arrays, total_leftover_params) layer_summaries = dict(arrays) diff --git a/src/lightning/pytorch/utilities/parameter_tying.py b/src/lightning/pytorch/utilities/parameter_tying.py index 8680285c272e8..da0309b0626bb 100644 --- a/src/lightning/pytorch/utilities/parameter_tying.py +++ b/src/lightning/pytorch/utilities/parameter_tying.py @@ -18,17 +18,17 @@ """ -from typing import Dict, List, Optional +from typing import Optional from torch import nn -def find_shared_parameters(module: nn.Module) -> List[str]: +def find_shared_parameters(module: nn.Module) -> list[str]: """Returns a list of names of shared parameters set in the module.""" return _find_shared_parameters(module) -def _find_shared_parameters(module: nn.Module, tied_parameters: Optional[Dict] = None, prefix: str = "") -> List[str]: +def _find_shared_parameters(module: nn.Module, tied_parameters: Optional[dict] = None, prefix: str = "") -> list[str]: if tied_parameters is None: tied_parameters = {} for name, param in module._parameters.items(): diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index 0f4460a3d5144..16eef555291bd 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -17,8 +17,9 @@ import inspect import pickle import types +from collections.abc import MutableMapping, Sequence from dataclasses import fields, is_dataclass -from typing import Any, Dict, List, Literal, MutableMapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Literal, Optional, Union from torch import nn @@ -48,7 +49,7 @@ def clean_namespace(hparams: MutableMapping) -> None: del hparams[k] -def parse_class_init_keys(cls: Type) -> Tuple[str, Optional[str], Optional[str]]: +def parse_class_init_keys(cls: type) -> tuple[str, Optional[str], Optional[str]]: """Parse key words for standard ``self``, ``*args`` and ``**kwargs``. Examples: @@ -60,7 +61,7 @@ def parse_class_init_keys(cls: Type) -> Tuple[str, Optional[str], Optional[str]] ('self', 'my_args', 'my_kwargs') """ - init_parameters = inspect.signature(cls.__init__).parameters + init_parameters = inspect.signature(cls.__init__).parameters # type: ignore[misc] # docs claims the params are always ordered # https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters init_params = list(init_parameters.values()) @@ -68,7 +69,7 @@ def parse_class_init_keys(cls: Type) -> Tuple[str, Optional[str], Optional[str]] n_self = init_params[0].name def _get_first_if_any( - params: List[inspect.Parameter], + params: list[inspect.Parameter], param_type: Literal[inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD], ) -> Optional[str]: for p in params: @@ -82,13 +83,13 @@ def _get_first_if_any( return n_self, n_args, n_kwargs -def get_init_args(frame: types.FrameType) -> Dict[str, Any]: # pragma: no-cover +def get_init_args(frame: types.FrameType) -> dict[str, Any]: # pragma: no-cover """For backwards compatibility: #16369.""" _, local_args = _get_init_args(frame) return local_args -def _get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any]]: +def _get_init_args(frame: types.FrameType) -> tuple[Optional[Any], dict[str, Any]]: _, _, _, local_vars = inspect.getargvalues(frame) if "__class__" not in local_vars: return None, {} @@ -109,10 +110,10 @@ def _get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any def collect_init_args( frame: types.FrameType, - path_args: List[Dict[str, Any]], + path_args: list[dict[str, Any]], inside: bool = False, - classes: Tuple[Type, ...] = (), -) -> List[Dict[str, Any]]: + classes: tuple[type, ...] = (), +) -> list[dict[str, Any]]: """Recursively collects the arguments passed to the child constructors in the inheritance tree. Args: @@ -147,7 +148,7 @@ def save_hyperparameters( *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None, - given_hparams: Optional[Dict[str, Any]] = None, + given_hparams: Optional[dict[str, Any]] = None, ) -> None: """See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`""" @@ -232,14 +233,14 @@ class AttributeDict(_AttributeDict): """ -def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str) -> List[Any]: +def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str) -> list[Any]: """Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ - holders: List[Any] = [] + holders: list[Any] = [] # Check if attribute in model if hasattr(model, attribute): diff --git a/src/lightning/pytorch/utilities/seed.py b/src/lightning/pytorch/utilities/seed.py index 4ba9e7f0f960f..7250ba59366c2 100644 --- a/src/lightning/pytorch/utilities/seed.py +++ b/src/lightning/pytorch/utilities/seed.py @@ -13,8 +13,8 @@ # limitations under the License. """Utilities to help with reproducibility of models.""" +from collections.abc import Generator from contextlib import contextmanager -from typing import Generator from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 03b3afd61b875..9c46913681143 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Tuple +from typing import Optional from lightning_utilities.core.imports import RequirementCache @@ -42,7 +42,7 @@ def _runif_reasons( psutil: bool = False, sklearn: bool = False, onnx: bool = False, -) -> Tuple[List[str], Dict[str, bool]]: +) -> tuple[list[str], dict[str, bool]]: """Construct reasons for pytest skipif. Args: diff --git a/src/lightning/pytorch/utilities/types.py b/src/lightning/pytorch/utilities/types.py index c1b971e924a52..3fd80d5e07a9a 100644 --- a/src/lightning/pytorch/utilities/types.py +++ b/src/lightning/pytorch/utilities/types.py @@ -17,19 +17,13 @@ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ +from collections.abc import Generator, Iterator, Mapping, Sequence from contextlib import contextmanager from dataclasses import dataclass from typing import ( Any, - Generator, - Iterator, - List, - Mapping, Optional, Protocol, - Sequence, - Tuple, - Type, TypedDict, Union, runtime_checkable, @@ -47,8 +41,8 @@ _NUMBER = Union[int, float] _METRIC = Union[Metric, Tensor, _NUMBER] STEP_OUTPUT = Optional[Union[Tensor, Mapping[str, Any]]] -_EVALUATE_OUTPUT = List[Mapping[str, float]] # 1 dict per DataLoader -_PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] +_EVALUATE_OUTPUT = list[Mapping[str, float]] # 1 dict per DataLoader +_PREDICT_OUTPUT = Union[list[Any], list[list[Any]]] TRAIN_DATALOADERS = Any # any iterable or collection of iterables EVAL_DATALOADERS = Any # any iterable or collection of iterables @@ -60,7 +54,7 @@ class DistributedDataParallel(Protocol): def __init__( self, module: torch.nn.Module, - device_ids: Optional[List[Union[int, torch.device]]] = None, + device_ids: Optional[list[Union[int, torch.device]]] = None, output_device: Optional[Union[int, torch.device]] = None, dim: int = 0, broadcast_buffers: bool = True, @@ -79,7 +73,7 @@ def no_sync(self) -> Generator: ... # todo: improve LRSchedulerType naming/typing LRSchedulerTypeTuple = (LRScheduler, ReduceLROnPlateau) LRSchedulerTypeUnion = Union[LRScheduler, ReduceLROnPlateau] -LRSchedulerType = Union[Type[LRScheduler], Type[ReduceLROnPlateau]] +LRSchedulerType = Union[type[LRScheduler], type[ReduceLROnPlateau]] LRSchedulerPLType = Union[LRScheduler, ReduceLROnPlateau] @@ -110,18 +104,25 @@ class LRSchedulerConfigType(TypedDict, total=False): strict: bool -class OptimizerLRSchedulerConfig(TypedDict): +class OptimizerConfigType(TypedDict): optimizer: Optimizer - lr_scheduler: NotRequired[Union[LRSchedulerTypeUnion, LRSchedulerConfigType]] + + +class OptimizerLRSchedulerConfigType(TypedDict): + optimizer: Optimizer + lr_scheduler: Union[LRSchedulerTypeUnion, LRSchedulerConfigType] + monitor: NotRequired[str] OptimizerLRScheduler = Optional[ Union[ Optimizer, Sequence[Optimizer], - Tuple[Sequence[Optimizer], Sequence[Union[LRSchedulerTypeUnion, LRSchedulerConfig]]], - OptimizerLRSchedulerConfig, - Sequence[OptimizerLRSchedulerConfig], + tuple[Sequence[Optimizer], Sequence[Union[LRSchedulerTypeUnion, LRSchedulerConfig]]], + OptimizerConfigType, + OptimizerLRSchedulerConfigType, + Sequence[OptimizerConfigType], + Sequence[OptimizerLRSchedulerConfigType], ] ] diff --git a/src/lightning/pytorch/utilities/upgrade_checkpoint.py b/src/lightning/pytorch/utilities/upgrade_checkpoint.py index 87ad6031f9f24..04cf000283d77 100644 --- a/src/lightning/pytorch/utilities/upgrade_checkpoint.py +++ b/src/lightning/pytorch/utilities/upgrade_checkpoint.py @@ -16,7 +16,6 @@ from argparse import ArgumentParser, Namespace from pathlib import Path from shutil import copyfile -from typing import List import torch from tqdm import tqdm @@ -29,7 +28,7 @@ def _upgrade(args: Namespace) -> None: path = Path(args.path).absolute() extension: str = args.extension if args.extension.startswith(".") else f".{args.extension}" - files: List[Path] = [] + files: list[Path] = [] if not path.exists(): _log.error( diff --git a/src/lightning_fabric/__setup__.py b/src/lightning_fabric/__setup__.py index 8fe0bc0937ef5..a55e1f2332f37 100644 --- a/src/lightning_fabric/__setup__.py +++ b/src/lightning_fabric/__setup__.py @@ -3,7 +3,7 @@ from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from types import ModuleType -from typing import Any, Dict +from typing import Any from pkg_resources import parse_requirements from setuptools import find_packages @@ -29,7 +29,7 @@ def _load_assistant() -> ModuleType: return _load_py_module("assistant", location) -def _prepare_extras() -> Dict[str, Any]: +def _prepare_extras() -> dict[str, Any]: assistant = _load_assistant() # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. @@ -49,7 +49,7 @@ def _prepare_extras() -> Dict[str, Any]: return extras -def _setup_args() -> Dict[str, Any]: +def _setup_args() -> dict[str, Any]: assistant = _load_assistant() about = _load_py_module("about", os.path.join(_PACKAGE_ROOT, "__about__.py")) version = _load_py_module("version", os.path.join(_PACKAGE_ROOT, "__version__.py")) @@ -73,7 +73,7 @@ def _setup_args() -> Dict[str, Any]: "include_package_data": True, "zip_safe": False, "keywords": ["deep learning", "pytorch", "AI"], - "python_requires": ">=3.8", + "python_requires": ">=3.9", "setup_requires": ["wheel"], "install_requires": assistant.load_requirements( _PATH_REQUIREMENTS, unfreeze="none" if _FREEZE_REQUIREMENTS else "all" @@ -105,7 +105,6 @@ def _setup_args() -> Dict[str, Any]: # Specify the Python versions you support here. In particular, ensure # that you indicate whether you support Python 2, Python 3 or both. "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md index a9200baa273dd..ae9339dfb2b0d 100644 --- a/src/pytorch_lightning/README.md +++ b/src/pytorch_lightning/README.md @@ -80,7 +80,6 @@ Lightning is rigorously tested across multiple CPUs, GPUs and TPUs and against m | System / PyTorch ver. | 1.12 | 1.13 | 2.0 | 2.1 | | :--------------------------------: | :---------------------------------------------------------------------------------------------------------: | ----------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------- | | Linux py3.9 \[GPUs\] | | | | ![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Fpytorch-lightning%20%28GPUs%29) | -| Linux py3.9 \[TPUs\] | | | ![Test PyTorch - TPU](https://github.com/Lightning-AI/lightning/actions/workflows/tpu-tests.yml/badge.svg) | | | Linux (multiple Python versions) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | | OSX (multiple Python versions) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | | Windows (multiple Python versions) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | diff --git a/src/pytorch_lightning/__setup__.py b/src/pytorch_lightning/__setup__.py index 7eedace6cac93..6677b469ba1de 100644 --- a/src/pytorch_lightning/__setup__.py +++ b/src/pytorch_lightning/__setup__.py @@ -3,7 +3,7 @@ from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from types import ModuleType -from typing import Any, Dict +from typing import Any from pkg_resources import parse_requirements from setuptools import find_packages @@ -29,7 +29,7 @@ def _load_assistant() -> ModuleType: return _load_py_module("assistant", location) -def _prepare_extras() -> Dict[str, Any]: +def _prepare_extras() -> dict[str, Any]: assistant = _load_assistant() # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. @@ -49,7 +49,7 @@ def _prepare_extras() -> Dict[str, Any]: return extras -def _setup_args() -> Dict[str, Any]: +def _setup_args() -> dict[str, Any]: assistant = _load_assistant() about = _load_py_module("about", os.path.join(_PACKAGE_ROOT, "__about__.py")) version = _load_py_module("version", os.path.join(_PACKAGE_ROOT, "__version__.py")) @@ -80,7 +80,7 @@ def _setup_args() -> Dict[str, Any]: "long_description_content_type": "text/markdown", "zip_safe": False, "keywords": ["deep learning", "pytorch", "AI"], - "python_requires": ">=3.8", + "python_requires": ">=3.9", "setup_requires": ["wheel"], # TODO: aggregate pytorch and lite requirements as we include its source code directly in this package. # this is not a problem yet because lite's base requirements are all included in pytorch's base requirements @@ -107,7 +107,6 @@ def _setup_args() -> Dict[str, Any]: "Operating System :: OS Independent", # Specify the Python versions you support here. "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/src/version.info b/src/version.info index 197c4d5c2d7c7..fe5c508359b11 100644 --- a/src/version.info +++ b/src/version.info @@ -1 +1 @@ -2.4.0 +2.5.0rc0 diff --git a/tests/parity_fabric/test_parity_ddp.py b/tests/parity_fabric/test_parity_ddp.py index 217d401ad6fba..aebd9064b31fd 100644 --- a/tests/parity_fabric/test_parity_ddp.py +++ b/tests/parity_fabric/test_parity_ddp.py @@ -162,5 +162,8 @@ def run_parity_test(accelerator: str = "cpu", devices: int = 2, tolerance: float if __name__ == "__main__": from jsonargparse import CLI + from lightning.pytorch.cli import patch_jsonargparse_python_3_12_8 + + patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641 CLI(run_parity_test) diff --git a/tests/run_standalone_tests.sh b/tests/run_standalone_tests.sh index 0aa0bacff168a..75a52e16c57dc 100755 --- a/tests/run_standalone_tests.sh +++ b/tests/run_standalone_tests.sh @@ -17,8 +17,13 @@ set -e # Batch size for testing: Determines how many standalone test invocations run in parallel # It can be set through the env variable PL_STANDALONE_TESTS_BATCH_SIZE and defaults to 6 if not set -test_batch_size="${PL_STANDALONE_TESTS_BATCH_SIZE:-6}" +test_batch_size="${PL_STANDALONE_TESTS_BATCH_SIZE:-3}" source="${PL_STANDALONE_TESTS_SOURCE:-"lightning"}" +# this is the directory where the tests are located +test_dir=$1 # parse the first argument +COLLECTED_TESTS_FILE="collected_tests.txt" + +ls -lh . # show the contents of the directory # this environment variable allows special tests to run export PL_RUN_STANDALONE_TESTS=1 @@ -26,71 +31,87 @@ export PL_RUN_STANDALONE_TESTS=1 defaults=" -m coverage run --source ${source} --append -m pytest --no-header -v -s --timeout 120 " echo "Using defaults: ${defaults}" -# get the testing location as the first argument -test_path=$1 -printf "source path: $test_path\n" - -# collect all tests with parametrization based filtering with PL_RUN_STANDALONE_TESTS -standalone_tests=$(python3 -m pytest $test_path -q --collect-only --pythonwarnings ignore) -printf "Collected tests: \n $standalone_tests\n" -# match only lines with tests -parametrizations=$(perl -nle 'print $& while m{\S+::test_\S+}g' <<< "$standalone_tests") -# convert the list to be array -parametrizations_arr=($parametrizations) -report='' +# get the list of parametrizations. we need to call them separately. the last two lines are removed. +# note: if there's a syntax error, this will fail with some garbled output +python3 -um pytest $test_dir -q --collect-only --pythonwarnings ignore 2>&1 > $COLLECTED_TESTS_FILE +# early terminate if collection failed (e.g. syntax error) +if [[ $? != 0 ]]; then + cat $COLLECTED_TESTS_FILE + exit 1 +fi -rm -f standalone_test_output.txt # in case it exists, remove it -rm -f testnames.txt +# removes the last line of the file +sed -i '$d' $COLLECTED_TESTS_FILE -function show_batched_output { - if [ -f standalone_test_output.txt ]; then # if exists - cat standalone_test_output.txt - # heuristic: stop if there's mentions of errors. this can prevent false negatives when only some of the ranks fail - if perl -nle 'print if /error|(?> testnames.txt - # fix the port to avoid race condition when batched distributed tests select the port randomly - export MASTER_PORT=$((29500 + $i % $test_batch_size)) +status=0 # reset the script status +report="" # final report +pids=() # array of PID for running tests +test_ids=() # array of indexes of running tests +printf "Running $test_count tests in batches of $test_batch_size\n" +for i in "${!tests[@]}"; do + # remove initial "tests/" from the test name + test=${tests[$i]/tests\//} + printf "Running test $((i+1))/$test_count: $test\n" # execute the test in the background - # redirect to a log file that buffers test output. since the tests will run in the background, we cannot let them - # output to std{out,err} because the outputs would be garbled together - python3 ${defaults} "$parametrization" &>> standalone_test_output.txt & - # save the PID in an array - pids[${i}]=$! - # add row to the final report - report+="Ran\t$parametrization\n" + # redirect to a log file that buffers test output. since the tests will run in the background, + # we cannot let them output to std{out,err} because the outputs would be garbled together + python3 ${defaults} "$test" 2>&1 > "standalone_test_output-$i.txt" & + test_ids+=($i) # save the test's id in an array with running tests + pids+=($!) # save the PID in an array with running tests - if ((($i + 1) % $test_batch_size == 0)); then + # if we reached the batch size, wait for all tests to finish + if (( (($i + 1) % $test_batch_size == 0) || $i == $test_count-1 )); then + printf "Waiting for batch to finish: $(IFS=' '; echo "${pids[@]}")\n" # wait for running tests - for pid in ${pids[*]}; do wait $pid; done - unset pids # empty the array - show_batched_output + for j in "${!test_ids[@]}"; do + i=${test_ids[$j]} # restore the global test's id + pid=${pids[$j]} # restore the particular PID + test=${tests[$i]} # restore the test name + printf "Waiting for $tests >> standalone_test_output-$i.txt (PID: $pid)\n" + wait -n $pid + # get the exit status of the test + test_status=$? + # add row to the final report + report+="Ran\t$test\t>> exit:$test_status\n" + if [[ $test_status != 0 ]]; then + # show the output of the failed test + cat "standalone_test_output-$i.txt" + # Process exited with a non-zero exit status + status=$test_status + fi + done + test_ids=() # reset the test's id array + pids=() # reset the PID array fi done -# wait for leftover tests -for pid in ${pids[*]}; do wait $pid; done -show_batched_output # echo test report printf '=%.s' {1..80} printf "\n$report" printf '=%.s' {1..80} printf '\n' + +# exit with the worst test result +exit $status diff --git a/tests/tests_fabric/accelerators/test_cuda.py b/tests/tests_fabric/accelerators/test_cuda.py index e323ada908cd1..0aed3675d93e1 100644 --- a/tests/tests_fabric/accelerators/test_cuda.py +++ b/tests/tests_fabric/accelerators/test_cuda.py @@ -121,27 +121,32 @@ def test_tf32_message(_, __, ___, caplog, monkeypatch): def test_find_usable_cuda_devices_error_handling(): """Test error handling for edge cases when using `find_usable_cuda_devices`.""" # Asking for GPUs if no GPUs visible - with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0), pytest.raises( - ValueError, match="You requested to find 2 devices but there are no visible CUDA" + with ( + mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0), + pytest.raises(ValueError, match="You requested to find 2 devices but there are no visible CUDA"), ): find_usable_cuda_devices(2) # Asking for more GPUs than are visible - with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1), pytest.raises( - ValueError, match="this machine only has 1 GPUs" + with ( + mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1), + pytest.raises(ValueError, match="this machine only has 1 GPUs"), ): find_usable_cuda_devices(2) # All GPUs are unusable tensor_mock = Mock(side_effect=RuntimeError) # simulate device placement fails - with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2), mock.patch( - "lightning.fabric.accelerators.cuda.torch.tensor", tensor_mock - ), pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")): + with ( + mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2), + mock.patch("lightning.fabric.accelerators.cuda.torch.tensor", tensor_mock), + pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")), + ): find_usable_cuda_devices(2) # Request for as many GPUs as there are, no error should be raised - with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5), mock.patch( - "lightning.fabric.accelerators.cuda.torch.tensor" + with ( + mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5), + mock.patch("lightning.fabric.accelerators.cuda.torch.tensor"), ): assert find_usable_cuda_devices(-1) == [0, 1, 2, 3, 4] diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index e8f39b6e83406..2544df1e01ff8 100644 --- a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import torch from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator @@ -30,7 +30,7 @@ def __init__(self, param1, param2): def setup_device(self, device: torch.device) -> None: pass - def get_device_stats(self, device: torch.device) -> Dict[str, Any]: + def get_device_stats(self, device: torch.device) -> dict[str, Any]: pass def teardown(self) -> None: diff --git a/tests/tests_fabric/accelerators/test_xla.py b/tests/tests_fabric/accelerators/test_xla.py index 1af7d7e1e7206..7a906c8ae0c54 100644 --- a/tests/tests_fabric/accelerators/test_xla.py +++ b/tests/tests_fabric/accelerators/test_xla.py @@ -44,3 +44,8 @@ def test_get_parallel_devices_raises(tpu_available): XLAAccelerator.get_parallel_devices(5) with pytest.raises(ValueError, match="Could not parse.*anything-else'"): XLAAccelerator.get_parallel_devices("anything-else") + + +@pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present") +def test_instantiate_xla_accelerator(): + _ = XLAAccelerator() diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index 446994167d0a1..dd272257b3923 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -14,8 +14,8 @@ import os import sys import threading +from concurrent.futures.process import _ExecutorManagerThread from pathlib import Path -from typing import List from unittest.mock import Mock import lightning.fabric @@ -25,9 +25,6 @@ from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver from lightning.fabric.utilities.distributed import _destroy_dist_connection -if sys.version_info >= (3, 9): - from concurrent.futures.process import _ExecutorManagerThread - @pytest.fixture(autouse=True) def preserve_global_rank_variable(): @@ -69,6 +66,7 @@ def restore_env_variables(): "OMP_NUM_THREADS", # set by our launchers # set by torchdynamo "TRITON_CACHE_DIR", + "TORCHINDUCTOR_CACHE_DIR", } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" @@ -200,7 +198,7 @@ def leave_no_artifacts_behind(): assert not difference, f"Test left artifacts behind: {difference}" -def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None: +def pytest_collection_modifyitems(items: list[pytest.Function], config: pytest.Config) -> None: """An adaptation of `tests/tests_pytorch/conftest.py::pytest_collection_modifyitems`""" initial_size = len(items) conditions = [] diff --git a/tests/tests_fabric/helpers/datasets.py b/tests/tests_fabric/helpers/datasets.py index 211e1f36a9ab5..ee14b21dc546c 100644 --- a/tests/tests_fabric/helpers/datasets.py +++ b/tests/tests_fabric/helpers/datasets.py @@ -1,4 +1,4 @@ -from typing import Iterator +from collections.abc import Iterator import torch from torch import Tensor diff --git a/tests/tests_fabric/plugins/collectives/test_torch_collective.py b/tests/tests_fabric/plugins/collectives/test_torch_collective.py index c8deb9d335161..b4c223e770282 100644 --- a/tests/tests_fabric/plugins/collectives/test_torch_collective.py +++ b/tests/tests_fabric/plugins/collectives/test_torch_collective.py @@ -29,13 +29,16 @@ @contextlib.contextmanager def check_destroy_group(): - with mock.patch( - "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group", - wraps=TorchCollective.new_group, - ) as mock_new, mock.patch( - "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group", - wraps=TorchCollective.destroy_group, - ) as mock_destroy: + with ( + mock.patch( + "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group", + wraps=TorchCollective.new_group, + ) as mock_new, + mock.patch( + "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group", + wraps=TorchCollective.destroy_group, + ) as mock_destroy, + ): yield # 0 to account for tests that mock distributed # -1 to account for destroying the default process group @@ -155,9 +158,10 @@ def test_repeated_create_and_destroy(): with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"): collective.create_group() - with mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), mock.patch( - "torch.distributed.destroy_process_group" - ) as destroy_mock: + with ( + mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), + mock.patch("torch.distributed.destroy_process_group") as destroy_mock, + ): collective.teardown() # this would be called twice if `init_process_group` wasn't patched. once for the group and once for the default # group @@ -300,9 +304,11 @@ def test_collective_manages_default_group(): assert TorchCollective.manages_default_group - with mock.patch.object(collective, "_group") as mock_group, mock.patch.dict( - "torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)} - ), mock.patch("torch.distributed.destroy_process_group") as destroy_mock: + with ( + mock.patch.object(collective, "_group") as mock_group, + mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}), + mock.patch("torch.distributed.destroy_process_group") as destroy_mock, + ): collective.teardown() destroy_mock.assert_called_once_with(mock_group) diff --git a/tests/tests_fabric/plugins/environments/test_lsf.py b/tests/tests_fabric/plugins/environments/test_lsf.py index b444f6fc4d781..4e60d968dc953 100644 --- a/tests/tests_fabric/plugins/environments/test_lsf.py +++ b/tests/tests_fabric/plugins/environments/test_lsf.py @@ -41,8 +41,9 @@ def test_empty_lsb_djob_rankfile(): def test_missing_lsb_job_id(tmp_path): """Test an error when the job id cannot be found.""" - with mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}), pytest.raises( - ValueError, match="Could not find job id in environment variable LSB_JOBID" + with ( + mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}), + pytest.raises(ValueError, match="Could not find job id in environment variable LSB_JOBID"), ): LSFEnvironment() diff --git a/tests/tests_fabric/plugins/environments/test_slurm.py b/tests/tests_fabric/plugins/environments/test_slurm.py index 73457ede41298..f237478a533f4 100644 --- a/tests/tests_fabric/plugins/environments/test_slurm.py +++ b/tests/tests_fabric/plugins/environments/test_slurm.py @@ -155,8 +155,9 @@ def test_srun_variable_validation(): """Test that we raise useful errors when `srun` variables are misconfigured.""" with mock.patch.dict(os.environ, {"SLURM_NTASKS": "1"}): SLURMEnvironment() - with mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}), pytest.raises( - RuntimeError, match="You set `--ntasks=2` in your SLURM" + with ( + mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}), + pytest.raises(RuntimeError, match="You set `--ntasks=2` in your SLURM"), ): SLURMEnvironment() diff --git a/tests/tests_fabric/plugins/precision/test_fsdp.py b/tests/tests_fabric/plugins/precision/test_fsdp.py index e42df493dd725..6a4968736ea86 100644 --- a/tests/tests_fabric/plugins/precision/test_fsdp.py +++ b/tests/tests_fabric/plugins/precision/test_fsdp.py @@ -127,3 +127,21 @@ def test_invalid_precision_with_fsdp_precision(): with pytest.raises(ValueError, match="is not supported in FSDP. `precision` must be one of"): FSDPPrecision(precision="64-true") + + +@pytest.mark.parametrize( + ("precision", "expected_dtype"), + [ + ("32-true", torch.float32), + ("bf16-mixed", torch.float32), + ("16-mixed", torch.float32), + ("bf16-true", torch.bfloat16), + ("16-true", torch.float16), + ], +) +def test_convert_module(precision, expected_dtype): + precision = FSDPPrecision(precision=precision) + module = torch.nn.Linear(2, 2) + assert module.weight.dtype == module.bias.dtype == torch.float32 + module = precision.convert_module(module) + assert module.weight.dtype == module.bias.dtype == expected_dtype diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py index b63d443098eac..6c595fba7acab 100644 --- a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py @@ -100,8 +100,11 @@ def test_check_for_bad_cuda_fork(mp_mock, _, start_method): def test_check_for_missing_main_guard(): launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn") - with mock.patch( - "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process", - return_value=Mock(_inheriting=True), # pretend that main is importing itself - ), pytest.raises(RuntimeError, match="requires that your script guards the main"): + with ( + mock.patch( + "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process", + return_value=Mock(_inheriting=True), # pretend that main is importing itself + ), + pytest.raises(RuntimeError, match="requires that your script guards the main"), + ): launcher.launch(function=Mock()) diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py index 56d9875dfefed..b98d5f8226dc2 100644 --- a/tests/tests_fabric/strategies/test_ddp.py +++ b/tests/tests_fabric/strategies/test_ddp.py @@ -58,9 +58,12 @@ def test_ddp_no_backward_sync(): strategy = DDPStrategy() assert isinstance(strategy._backward_sync_control, _DDPBackwardSyncControl) - with pytest.raises( - TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`" - ), strategy._backward_sync_control.no_backward_sync(Mock(), True): + with ( + pytest.raises( + TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`" + ), + strategy._backward_sync_control.no_backward_sync(Mock(), True), + ): pass module = MagicMock(spec=DistributedDataParallel) diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index 3be535effa078..4811599ed05ab 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -404,9 +404,11 @@ def test_deepspeed_init_module_with_stages_1_2(stage, empty_init): fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision="bf16-true") fabric.launch() - with mock.patch("deepspeed.zero.Init") as zero_init_mock, mock.patch( - "torch.Tensor.uniform_" - ) as init_mock, fabric.init_module(empty_init=empty_init): + with ( + mock.patch("deepspeed.zero.Init") as zero_init_mock, + mock.patch("torch.Tensor.uniform_") as init_mock, + fabric.init_module(empty_init=empty_init), + ): model = BoringModel() zero_init_mock.assert_called_with(enabled=False, remote_device=None, config_dict_or_path=ANY) diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index 0c46e7ac1763c..cb6542cdb6243 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -133,9 +133,12 @@ def test_no_backward_sync(): strategy = FSDPStrategy() assert isinstance(strategy._backward_sync_control, _FSDPBackwardSyncControl) - with pytest.raises( - TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`" - ), strategy._backward_sync_control.no_backward_sync(Mock(), True): + with ( + pytest.raises( + TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`" + ), + strategy._backward_sync_control.no_backward_sync(Mock(), True), + ): pass module = MagicMock(spec=FullyShardedDataParallel) @@ -172,9 +175,12 @@ def __init__(self): assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) strategy._parallel_devices = [torch.device("cuda", 0)] - with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch( - "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" - ) as apply_mock: + with ( + mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), + mock.patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" + ) as apply_mock, + ): wrapped = strategy.setup_module(Model()) apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs) diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py index dfbdb16b10060..b04a29b691529 100644 --- a/tests/tests_fabric/strategies/test_model_parallel_integration.py +++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py @@ -29,6 +29,13 @@ from tests_fabric.helpers.runif import RunIf +@pytest.fixture() +def distributed(): + yield + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + class FeedForward(nn.Module): def __init__(self): super().__init__() @@ -81,7 +88,7 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh): @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_setup_device_mesh(): +def test_setup_device_mesh(distributed): from torch.distributed.device_mesh import DeviceMesh for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)): @@ -116,11 +123,28 @@ def test_setup_device_mesh(): assert fabric.strategy.device_mesh.size(1) == 4 +def _parallelize_with_compile(parallelize): + def fn(model, device_mesh): + model = parallelize(model, device_mesh) + return torch.compile(model) + + return fn + + @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2) -def test_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor - strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_tp) + parallelize = _parallelize_feed_forward_tp + + if compile: + parallelize = _parallelize_with_compile(parallelize) + + strategy = ModelParallelStrategy(parallelize_fn=parallelize) fabric = Fabric(accelerator="auto", devices=2, strategy=strategy) fabric.launch() @@ -161,9 +185,18 @@ def test_tensor_parallel(): @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_fsdp2_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_fsdp2_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor + parallelize = _parallelize_feed_forward_fsdp2_tp + + if compile: + parallelize = _parallelize_with_compile(parallelize) + strategy = ModelParallelStrategy( parallelize_fn=_parallelize_feed_forward_fsdp2_tp, data_parallel_size=2, @@ -238,6 +271,7 @@ def _train(fabric, model=None, optimizer=None): @RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True) +@pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.parametrize( "precision", [ @@ -245,7 +279,7 @@ def _train(fabric, model=None, optimizer=None): pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)), ], ) -def test_train_save_load(precision, tmp_path): +def test_train_save_load(distributed, precision, tmp_path): """Test 2D-parallel training, saving and loading precision settings.""" strategy = ModelParallelStrategy( _parallelize_feed_forward_fsdp2_tp, @@ -303,7 +337,7 @@ def test_train_save_load(precision, tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_save_full_state_dict(tmp_path): +def test_save_full_state_dict(distributed, tmp_path): """Test that ModelParallelStrategy saves the full state into a single file with `save_distributed_checkpoint=False`.""" from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict @@ -404,7 +438,7 @@ def test_save_full_state_dict(tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_load_full_state_dict_into_sharded_model(tmp_path): +def test_load_full_state_dict_into_sharded_model(distributed, tmp_path): """Test that the strategy can load a full-state checkpoint into a distributed model.""" fabric = Fabric(accelerator="cuda", devices=1) fabric.seed_everything(0) @@ -450,7 +484,7 @@ def test_load_full_state_dict_into_sharded_model(tmp_path): @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize("move_to_device", [True, False]) @mock.patch("lightning.fabric.wrappers._FabricModule") -def test_setup_module_move_to_device(fabric_module_mock, move_to_device): +def test_setup_module_move_to_device(fabric_module_mock, move_to_device, distributed): """Test that `move_to_device` does nothing, ModelParallel decides which device parameters get moved to which device (sharding).""" from torch.distributed._tensor import DTensor @@ -482,7 +516,7 @@ def test_setup_module_move_to_device(fabric_module_mock, move_to_device): pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)), ], ) -def test_module_init_context(precision, expected_dtype): +def test_module_init_context(distributed, precision, expected_dtype): """Test that the module under the init-context gets moved to the right device and dtype.""" strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_fsdp2) fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision=precision) @@ -505,7 +539,7 @@ def _run_setup_assertions(empty_init, expected_device): @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_save_filter(tmp_path): +def test_save_filter(distributed, tmp_path): strategy = ModelParallelStrategy( parallelize_fn=_parallelize_feed_forward_fsdp2, save_distributed_checkpoint=False, @@ -558,7 +592,7 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh): "val", ], ) -def test_clip_gradients(clip_type, precision): +def test_clip_gradients(distributed, clip_type, precision): strategy = ModelParallelStrategy(_parallelize_single_linear_tp_fsdp2) fabric = Fabric(accelerator="auto", devices=2, precision=precision, strategy=strategy) fabric.launch() @@ -600,7 +634,7 @@ def test_clip_gradients(clip_type, precision): @RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True) -def test_save_sharded_and_consolidate_and_load(tmp_path): +def test_save_sharded_and_consolidate_and_load(distributed, tmp_path): """Test the consolidation of a distributed (DTensor) checkpoint into a single file.""" strategy = ModelParallelStrategy( _parallelize_feed_forward_fsdp2_tp, @@ -657,7 +691,7 @@ def test_save_sharded_and_consolidate_and_load(tmp_path): @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_load_raw_module_state(): +def test_load_raw_module_state(distributed): from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py index e2864b684c4a7..879a55cf77f34 100644 --- a/tests/tests_fabric/strategies/test_xla_fsdp.py +++ b/tests/tests_fabric/strategies/test_xla_fsdp.py @@ -48,9 +48,12 @@ def test_xla_fsdp_no_backward_sync(): strategy = XLAFSDPStrategy() assert isinstance(strategy._backward_sync_control, _XLAFSDPBackwardSyncControl) - with pytest.raises( - TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`" - ), strategy._backward_sync_control.no_backward_sync(object(), True): + with ( + pytest.raises( + TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`" + ), + strategy._backward_sync_control.no_backward_sync(object(), True), + ): pass module = MagicMock(spec=XlaFullyShardedDataParallel) diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 08d6dbb45ed91..8a6e9206b3df5 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -15,7 +15,7 @@ import os import sys from contextlib import nullcontext -from typing import Any, Dict +from typing import Any from unittest import mock from unittest.mock import Mock @@ -165,7 +165,7 @@ class Accel(Accelerator): def setup_device(self, device: torch.device) -> None: pass - def get_device_stats(self, device: torch.device) -> Dict[str, Any]: + def get_device_stats(self, device: torch.device) -> dict[str, Any]: pass def teardown(self) -> None: @@ -960,28 +960,33 @@ def test_arguments_from_environment_collision(): with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}): _Connector(accelerator="cuda") - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}), pytest.raises( - ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`" + with ( + mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}), + pytest.raises(ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`"), ): _Connector(accelerator="cuda") - with mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}), pytest.raises( - ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`" + with ( + mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}), + pytest.raises(ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`"), ): _Connector(strategy="ddp_spawn") - with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}), pytest.raises( - ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`" + with ( + mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}), + pytest.raises(ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`"), ): _Connector(devices=3) - with mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}), pytest.raises( - ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`" + with ( + mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}), + pytest.raises(ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`"), ): _Connector(num_nodes=2) - with mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}), pytest.raises( - ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`" + with ( + mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}), + pytest.raises(ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`"), ): _Connector(precision="64-true") diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index 70d04d5431404..7bb6b29eceaf2 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -746,9 +746,10 @@ def test_no_backward_sync(): # pretend that the strategy does not support skipping backward sync fabric._strategy = Mock(spec=ParallelStrategy, _backward_sync_control=None) - with pytest.warns( - PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the" - ), fabric.no_backward_sync(model): + with ( + pytest.warns(PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the"), + fabric.no_backward_sync(model), + ): pass # for single-device strategies, it becomes a no-op without warning diff --git a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py index 216b77e6b9299..2584aab8bdc2e 100644 --- a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py +++ b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py @@ -41,8 +41,9 @@ def test_parse_cli_args(args, expected): def test_process_cli_args(tmp_path, caplog, monkeypatch): # PyTorch version < 2.3 monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_3", False) - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace()) assert "requires PyTorch >= 2.3." in caplog.text @@ -51,8 +52,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint does not exist checkpoint_folder = Path("does/not/exist") - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace(checkpoint_folder=checkpoint_folder)) assert f"checkpoint folder does not exist: {checkpoint_folder}" in caplog.text @@ -61,8 +63,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint exists but is not a folder file = tmp_path / "checkpoint_file" file.touch() - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace(checkpoint_folder=file)) assert "checkpoint path must be a folder" in caplog.text @@ -71,8 +74,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint exists but is not an FSDP checkpoint folder = tmp_path / "checkpoint_folder" folder.mkdir() - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace(checkpoint_folder=folder)) assert "Only FSDP-sharded checkpoints saved with Lightning are supported" in caplog.text @@ -89,8 +93,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint is a FSDP folder, output file already exists file = tmp_path / "ouput_file" file.touch() - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace(checkpoint_folder=folder, output_file=file)) assert "path for the converted checkpoint already exists" in caplog.text diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index cc6c23bddbd7b..f5a78a1529a52 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -215,9 +215,10 @@ def test_infinite_barrier(): # distributed available barrier = _InfiniteBarrier() - with mock.patch( - "lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True - ), mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock: + with ( + mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True), + mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock, + ): barrier.__enter__() dist_mock.new_group.assert_called_once() assert barrier.barrier == barrier.group.monitored_barrier diff --git a/tests/tests_fabric/utilities/test_imports.py b/tests/tests_fabric/utilities/test_imports.py index 43ee41a7b3035..85408a4ff83bd 100644 --- a/tests/tests_fabric/utilities/test_imports.py +++ b/tests/tests_fabric/utilities/test_imports.py @@ -23,6 +23,13 @@ def test_import_fabric_with_torch_dist_unavailable(): code = dedent( """ import torch + try: + # PyTorch 2.5 relies on torch,distributed._composable.fsdp not + # existing with USE_DISTRIBUTED=0 + import torch._dynamo.variables.functions + torch._dynamo.variables.functions._fsdp_param_group = None + except ImportError: + pass # pretend torch.distributed not available for name in list(torch.distributed.__dict__.keys()): @@ -31,6 +38,11 @@ def test_import_fabric_with_torch_dist_unavailable(): torch.distributed.is_available = lambda: False + # needed for Dynamo in PT 2.5+ compare the torch.distributed source + class _ProcessGroupStub: + pass + torch.distributed.ProcessGroup = _ProcessGroupStub + import lightning.fabric """ ) diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py index eefadb285af02..d410d0766d97b 100644 --- a/tests/tests_fabric/utilities/test_throughput.py +++ b/tests/tests_fabric/utilities/test_throughput.py @@ -39,8 +39,9 @@ def test_get_available_flops(xla_available): with pytest.warns(match="not found for 'CocoNut"), mock.patch("torch.cuda.get_device_name", return_value="CocoNut"): assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None - with pytest.warns(match="t4' does not support torch.bfloat"), mock.patch( - "torch.cuda.get_device_name", return_value="t4" + with ( + pytest.warns(match="t4' does not support torch.bfloat"), + mock.patch("torch.cuda.get_device_name", return_value="t4"), ): assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None diff --git a/tests/tests_pytorch/accelerators/test_common.py b/tests/tests_pytorch/accelerators/test_common.py index 7654125af3e44..6967bffd9ffa2 100644 --- a/tests/tests_pytorch/accelerators/test_common.py +++ b/tests/tests_pytorch/accelerators/test_common.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import torch from lightning.pytorch import Trainer @@ -24,7 +24,7 @@ class TestAccelerator(Accelerator): def setup_device(self, device: torch.device) -> None: pass - def get_device_stats(self, device: torch.device) -> Dict[str, Any]: + def get_device_stats(self, device: torch.device) -> dict[str, Any]: pass def teardown(self) -> None: diff --git a/tests/tests_pytorch/accelerators/test_cpu.py b/tests/tests_pytorch/accelerators/test_cpu.py index cd34fe3f2b318..844556064621d 100644 --- a/tests/tests_pytorch/accelerators/test_cpu.py +++ b/tests/tests_pytorch/accelerators/test_cpu.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Union from unittest.mock import Mock import lightning.pytorch as pl @@ -53,7 +53,7 @@ def setup(self, trainer: "pl.Trainer") -> None: def restore_checkpoint_after_setup(self) -> bool: return restore_after_pre_setup - def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> dict[str, Any]: assert self.setup_called == restore_after_pre_setup return super().load_checkpoint(checkpoint_path) diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index de41035d4d832..89c1effe839a8 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -141,9 +141,12 @@ def on_train_start(self) -> None: model = TestModel() - with mock.patch( - "lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True - ) as mock_progress_stop, pytest.raises(SystemExit): + with ( + mock.patch( + "lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True + ) as mock_progress_stop, + pytest.raises(SystemExit), + ): progress_bar = RichProgressBar() trainer = Trainer( default_root_dir=tmp_path, @@ -308,20 +311,6 @@ def test_rich_progress_bar_counter_with_val_check_interval(tmp_path): assert val_bar.total == 4 -@RunIf(rich=True) -@mock.patch("lightning.pytorch.callbacks.progress.rich_progress._detect_light_colab_theme", return_value=True) -def test_rich_progress_bar_colab_light_theme_update(*_): - theme = RichProgressBar().theme - assert theme.description == "black" - assert theme.batch_progress == "black" - assert theme.metrics == "black" - - theme = RichProgressBar(theme=RichProgressBarTheme(description="blue", metrics="red")).theme - assert theme.description == "blue" - assert theme.batch_progress == "black" - assert theme.metrics == "red" - - @RunIf(rich=True) def test_rich_progress_bar_metric_display_task_id(tmp_path): class CustomModel(BoringModel): diff --git a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py index aacee958faa45..f1d999f1df61a 100644 --- a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py +++ b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py @@ -14,7 +14,7 @@ import csv import os import re -from typing import Dict, Optional +from typing import Optional from unittest import mock from unittest.mock import Mock @@ -40,7 +40,7 @@ def test_device_stats_gpu_from_torch(tmp_path): class DebugLogger(CSVLogger): @rank_zero_only - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: fields = [ "allocated_bytes.all.freed", "inactive_split.all.peak", @@ -74,7 +74,7 @@ def test_device_stats_cpu(cpu_stats_mock, tmp_path, cpu_stats): CPU_METRIC_KEYS = (_CPU_VM_PERCENT, _CPU_SWAP_PERCENT, _CPU_PERCENT) class DebugLogger(CSVLogger): - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: enabled = cpu_stats is not False for f in CPU_METRIC_KEYS: has_cpu_metrics = any(f in h for h in metrics) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index 633c1dc0853e0..a3d56bb0135c3 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -15,15 +15,13 @@ import math import os import pickle -from contextlib import nullcontext -from typing import List, Optional +from typing import Optional from unittest import mock from unittest.mock import Mock import cloudpickle import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel @@ -193,13 +191,11 @@ def test_pickling(): early_stopping = EarlyStopping(monitor="foo") early_stopping_pickled = pickle.dumps(early_stopping) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): - early_stopping_loaded = pickle.loads(early_stopping_pickled) + early_stopping_loaded = pickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) early_stopping_pickled = cloudpickle.dumps(early_stopping) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): - early_stopping_loaded = cloudpickle.loads(early_stopping_pickled) + early_stopping_loaded = cloudpickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) @@ -411,7 +407,7 @@ def on_train_end(self) -> None: ) def test_multiple_early_stopping_callbacks( tmp_path, - callbacks: List[EarlyStopping], + callbacks: list[EarlyStopping], expected_stop_epoch: int, check_on_train_epoch_end: bool, strategy: str, diff --git a/tests/tests_pytorch/callbacks/test_model_summary.py b/tests/tests_pytorch/callbacks/test_model_summary.py index b42907dc9a38d..215176ee2376b 100644 --- a/tests/tests_pytorch/callbacks/test_model_summary.py +++ b/tests/tests_pytorch/callbacks/test_model_summary.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple +from typing import Any from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelSummary @@ -45,7 +45,7 @@ def test_custom_model_summary_callback_summarize(tmp_path): class CustomModelSummary(ModelSummary): @staticmethod def summarize( - summary_data: List[Tuple[str, List[str]]], + summary_data: list[tuple[str, list[str]]], total_parameters: int, trainable_parameters: int, model_size: float, diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index d57ac76f04400..8d3a1800e1fa2 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -13,8 +13,9 @@ # limitations under the License. import logging import os +from contextlib import AbstractContextManager from pathlib import Path -from typing import ContextManager, Optional +from typing import Optional from unittest import mock import pytest @@ -382,5 +383,5 @@ def test_misconfiguration_error_with_sharded_model(tmp_path, strategy: str): trainer.fit(model) -def _backward_patch(trainer: Trainer) -> ContextManager: +def _backward_patch(trainer: Trainer) -> AbstractContextManager: return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py index a74efba75813b..4867134a85642 100644 --- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py +++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py @@ -43,8 +43,9 @@ def test_throughput_monitor_fit(tmp_path): ) # these timing results are meant to precisely match the `test_throughput_monitor` test in fabric timings = [0.0] + [0.5 + i for i in range(1, 6)] - with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch( - "time.perf_counter", side_effect=timings + with ( + mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), + mock.patch("time.perf_counter", side_effect=timings), ): trainer.fit(model) @@ -179,8 +180,9 @@ def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_pat enable_progress_bar=False, ) timings = [0.0] + [0.5 + i for i in range(1, 11)] - with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch( - "time.perf_counter", side_effect=timings + with ( + mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), + mock.patch("time.perf_counter", side_effect=timings), ): trainer.fit(model) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 8ef78a742f9a7..ef0abc7c463a8 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -17,7 +17,6 @@ import re import time from argparse import Namespace -from contextlib import nullcontext from datetime import timedelta from inspect import signature from pathlib import Path @@ -30,11 +29,10 @@ import pytest import torch import yaml -from jsonargparse import ArgumentParser from lightning.fabric.utilities.cloud_io import _load as pl_load -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.cli import LightningArgumentParser as ArgumentParser from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -352,13 +350,11 @@ def test_pickling(tmp_path): ckpt = ModelCheckpoint(dirpath=tmp_path) ckpt_pickled = pickle.dumps(ckpt) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): - ckpt_loaded = pickle.loads(ckpt_pickled) + ckpt_loaded = pickle.loads(ckpt_pickled) assert vars(ckpt) == vars(ckpt_loaded) ckpt_pickled = cloudpickle.dumps(ckpt) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): - ckpt_loaded = cloudpickle.loads(ckpt_pickled) + ckpt_loaded = cloudpickle.loads(ckpt_pickled) assert vars(ckpt) == vars(ckpt_loaded) diff --git a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py index 3ae7d6be4995e..c07400eaf8446 100644 --- a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py @@ -92,9 +92,10 @@ def test_trainer_save_checkpoint_storage_options(tmp_path, xla_available): io_mock.assert_called_with(ANY, instance_path, storage_options=None) checkpoint_mock = Mock() - with mock.patch.object(trainer.strategy, "save_checkpoint") as save_mock, mock.patch.object( - trainer._checkpoint_connector, "dump_checkpoint", return_value=checkpoint_mock - ) as dump_mock: + with ( + mock.patch.object(trainer.strategy, "save_checkpoint") as save_mock, + mock.patch.object(trainer._checkpoint_connector, "dump_checkpoint", return_value=checkpoint_mock) as dump_mock, + ): trainer.save_checkpoint(instance_path, True) dump_mock.assert_called_with(True) save_mock.assert_called_with(checkpoint_mock, instance_path, storage_options=None) diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index 78e81c7c5fa26..ea5207516cad1 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -15,10 +15,10 @@ import signal import sys import threading +from concurrent.futures.process import _ExecutorManagerThread from functools import partial from http.server import SimpleHTTPRequestHandler from pathlib import Path -from typing import List from unittest.mock import Mock import lightning.fabric @@ -35,9 +35,6 @@ from tests_pytorch import _PATH_DATASETS -if sys.version_info >= (3, 9): - from concurrent.futures.process import _ExecutorManagerThread - @pytest.fixture(scope="session") def datadir(): @@ -323,7 +320,7 @@ def leave_no_artifacts_behind(): assert not difference, f"Test left artifacts behind: {difference}" -def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None: +def pytest_collection_modifyitems(items: list[pytest.Function], config: pytest.Config) -> None: initial_size = len(items) conditions = [] filtered, skipped = 0, 0 diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 65fccb691a33d..b3ccd88aae704 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import pickle from argparse import Namespace from dataclasses import dataclass -from typing import Any, Dict +from typing import Any from unittest import mock from unittest.mock import Mock, PropertyMock, call @@ -22,7 +23,12 @@ import torch from lightning.pytorch import LightningDataModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint -from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel +from lightning.pytorch.demos.boring_classes import ( + BoringDataModule, + BoringDataModuleNoLen, + BoringModel, + IterableBoringDataModule, +) from lightning.pytorch.profilers.simple import SimpleProfiler from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities import AttributeDict @@ -187,10 +193,10 @@ def validation_step(self, batch, batch_idx): return out class CustomBoringDataModule(BoringDataModule): - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {"my": "state_dict"} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.my_state_dict = state_dict dm = CustomBoringDataModule() @@ -510,3 +516,107 @@ def prepare_data(self): durations = profiler.recorded_durations[key] assert len(durations) == 1 assert durations[0] > 0 + + +def test_datamodule_string_not_available(): + dm = BoringDataModule() + + expected_output = ( + f"{{Train dataloader: None}}{os.linesep}" + f"{{Validation dataloader: None}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" + ) + out = str(dm) + + assert out == expected_output + + +def test_datamodule_string_fit_setup(): + dm = BoringDataModule() + dm.setup(stage="fit") + + expected_output = ( + f"{{Train dataloader: size=64}}{os.linesep}" + f"{{Validation dataloader: size=64}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" + ) + output = str(dm) + + assert expected_output == output + + +def test_datamodule_string_validation_setup(): + dm = BoringDataModule() + dm.setup(stage="validate") + + expected_output = ( + f"{{Train dataloader: None}}{os.linesep}" + f"{{Validation dataloader: size=64}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" + ) + output = str(dm) + + assert expected_output == output + + +def test_datamodule_string_test_setup(): + dm = BoringDataModule() + dm.setup(stage="test") + + expected_output = ( + f"{{Train dataloader: None}}{os.linesep}" + f"{{Validation dataloader: None}}{os.linesep}" + f"{{Test dataloader: size=64}}{os.linesep}" + f"{{Predict dataloader: None}}" + ) + output = str(dm) + + assert expected_output == output + + +def test_datamodule_string_predict_setup(): + dm = BoringDataModule() + dm.setup(stage="predict") + + expected_output = ( + f"{{Train dataloader: None}}{os.linesep}" + f"{{Validation dataloader: None}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: size=64}}" + ) + output = str(dm) + + assert expected_output == output + + +def test_datamodule_string_no_len(): + dm = BoringDataModuleNoLen() + dm.setup("fit") + + expected_output = ( + f"{{Train dataloader: size=NA}}{os.linesep}" + f"{{Validation dataloader: size=NA}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" + ) + output = str(dm) + + assert output == expected_output + + +def test_datamodule_string_iterable(): + dm = IterableBoringDataModule() + dm.setup("fit") + + expected_output = ( + f"{{Train dataloader: 1. size=16 ; 2. size=NA}}{os.linesep}" + f"{{Validation dataloader: 1. size=32 ; 2. size=NA}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" + ) + output = str(dm) + + assert output == expected_output diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py index 8ab6eca907ce6..b25b7ae648a3a 100644 --- a/tests/tests_pytorch/core/test_lightning_optimizer.py +++ b/tests/tests_pytorch/core/test_lightning_optimizer.py @@ -110,9 +110,10 @@ def configure_optimizers(self): default_root_dir=tmp_path, limit_train_batches=8, limit_val_batches=1, max_epochs=1, enable_model_summary=False ) - with patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, patch.multiple( - torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT - ) as adam: + with ( + patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, + patch.multiple(torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT) as adam, + ): trainer.fit(model) assert sgd["step"].call_count == 4 diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index 9818f9807ae6d..dcb3f71c7499c 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -19,7 +19,6 @@ import lightning.pytorch as pl import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.callbacks import OnExceptionCheckpoint @@ -254,8 +253,7 @@ def lightning_log(fx, *args, **kwargs): } # make sure can be pickled - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): - pickle.loads(pickle.dumps(result)) + pickle.loads(pickle.dumps(result)) # make sure can be torch.loaded filepath = str(tmp_path / "result") torch.save(result, filepath) @@ -627,8 +625,9 @@ def test_logger_sync_dist(distributed_env, log_val): else nullcontext() ) - with warning_ctx( - PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`" - ), patch_ctx: + with ( + warning_ctx(PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"), + patch_ctx, + ): value = _ResultCollection._get_cache(result_metric, on_step=False) assert value == 0.5 diff --git a/tests/tests_pytorch/helpers/datasets.py b/tests/tests_pytorch/helpers/datasets.py index 9b1d4ec7353cb..014fb374e5d5e 100644 --- a/tests/tests_pytorch/helpers/datasets.py +++ b/tests/tests_pytorch/helpers/datasets.py @@ -16,7 +16,8 @@ import random import time import urllib.request -from typing import Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import Optional import torch from torch import Tensor @@ -63,7 +64,7 @@ def __init__( data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file)) - def __getitem__(self, idx: int) -> Tuple[Tensor, int]: + def __getitem__(self, idx: int) -> tuple[Tensor, int]: img = self.data[idx].float().unsqueeze(0) target = int(self.targets[idx]) diff --git a/tests/tests_pytorch/helpers/test_datasets.py b/tests/tests_pytorch/helpers/test_datasets.py index ddc20c29e62e8..d71ed118fe835 100644 --- a/tests/tests_pytorch/helpers/test_datasets.py +++ b/tests/tests_pytorch/helpers/test_datasets.py @@ -12,12 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import pickle -from contextlib import nullcontext import cloudpickle import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from tests_pytorch import _PATH_DATASETS from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST @@ -44,9 +42,7 @@ def test_pickling_dataset_mnist(dataset_cls, args): mnist = dataset_cls(**args) mnist_pickled = pickle.dumps(mnist) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): - pickle.loads(mnist_pickled) + pickle.loads(mnist_pickled) mnist_pickled = cloudpickle.dumps(mnist) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): - cloudpickle.loads(mnist_pickled) + cloudpickle.loads(mnist_pickled) diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index 7cc5cc94fe8cc..3912a4dea0db9 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -94,18 +94,20 @@ def comet_mock(monkeypatch): comet = ModuleType("comet_ml") monkeypatch.setitem(sys.modules, "comet_ml", comet) - comet.Experiment = Mock() - comet.ExistingExperiment = Mock() - comet.OfflineExperiment = Mock() - comet.API = Mock() + # to support dunder methods calling we will create a special mock + comet_experiment = MagicMock(name="CommonExperiment") + setattr(comet_experiment, "__internal_api__set_model_graph__", MagicMock()) + setattr(comet_experiment, "__internal_api__log_metrics__", MagicMock()) + setattr(comet_experiment, "__internal_api__log_parameters__", MagicMock()) + + comet.Experiment = MagicMock(name="Experiment", return_value=comet_experiment) + comet.ExistingExperiment = MagicMock(name="ExistingExperiment", return_value=comet_experiment) + comet.OfflineExperiment = MagicMock(name="OfflineExperiment", return_value=comet_experiment) + + comet.ExperimentConfig = Mock() + comet.start = Mock(name="comet_ml.start", return_value=comet.Experiment()) comet.config = Mock() - comet_api = ModuleType("api") - comet_api.API = Mock() - monkeypatch.setitem(sys.modules, "comet_ml.api", comet_api) - - comet.api = comet_api - monkeypatch.setattr("lightning.pytorch.loggers.comet._COMET_AVAILABLE", True) return comet diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 503e49fe6cdad..d6de763b8f74e 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -14,13 +14,11 @@ import inspect import os import pickle -from contextlib import nullcontext from unittest import mock from unittest.mock import ANY, Mock import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import ( @@ -107,7 +105,9 @@ def log_metrics(self, metrics, step): if logger_class == CometLogger: logger.experiment.id = "foo" - logger.experiment.project_name = "bar" + logger._comet_config.offline_directory = None + logger._project_name = "bar" + logger.experiment.get_key.return_value = "SOME_KEY" if logger_class == NeptuneLogger: logger._retrieve_run_data = Mock() @@ -163,7 +163,7 @@ def test_loggers_pickle_all(tmp_path, monkeypatch, logger_class): pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.") -def _test_loggers_pickle(tmp_path, monkeypatch, logger_class): +def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger): """Verify that pickling trainer with logger works.""" _patch_comet_atexit(monkeypatch) @@ -184,8 +184,7 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class): trainer = Trainer(max_epochs=1, logger=logger) pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): - trainer2 = pickle.loads(pkl_bytes) + trainer2 = pickle.loads(pkl_bytes) trainer2.logger.log_metrics({"acc": 1.0}) # make sure we restored properly @@ -295,7 +294,9 @@ def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, neptune_moc _patch_comet_atexit(monkeypatch) logger = _instantiate_logger(CometLogger, save_dir=tmp_path, prefix=prefix) logger.log_metrics({"test": 1.0}, step=0) - logger.experiment.log_metrics.assert_called_once_with({"tmp-test": 1.0}, epoch=None, step=0) + logger.experiment.__internal_api__log_metrics__.assert_called_once_with( + {"test": 1.0}, epoch=None, step=0, prefix=prefix, framework="pytorch-lightning" + ) # MLflow Metric = mlflow_mock.entities.Metric diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index e467c63543ede..34c24211d13c9 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -13,15 +13,13 @@ # limitations under the License. import os from unittest import mock -from unittest.mock import DEFAULT, Mock, patch +from unittest.mock import Mock, call -import pytest -from lightning.pytorch import Trainer -from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import CometLogger -from lightning.pytorch.utilities.exceptions import MisconfigurationException from torch import tensor +FRAMEWORK_NAME = "pytorch-lightning" + def _patch_comet_atexit(monkeypatch): """Prevent comet logger from trying to print at exit, since pytest's stdout/stderr redirection breaks it.""" @@ -33,195 +31,163 @@ def _patch_comet_atexit(monkeypatch): @mock.patch.dict(os.environ, {}) def test_comet_logger_online(comet_mock): """Test comet online with mocks.""" - # Test api_key given - comet_experiment = comet_mock.Experiment - logger = CometLogger(api_key="key", workspace="dummy-test", project_name="general") - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key="key", workspace="dummy-test", project_name="general") - - # Test both given - comet_experiment.reset_mock() - logger = CometLogger(save_dir="test", api_key="key", workspace="dummy-test", project_name="general") - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key="key", workspace="dummy-test", project_name="general") - - # Test already exists - comet_existing = comet_mock.ExistingExperiment - logger = CometLogger( - experiment_key="test", - experiment_name="experiment", + + comet_start = comet_mock.start + + # Test api_key given with old param "project_name" + _logger = CometLogger(api_key="key", workspace="dummy-test", project_name="general") + comet_start.assert_called_once_with( api_key="key", workspace="dummy-test", - project_name="general", + project="general", + experiment_key=None, + mode=None, + online=None, + experiment_config=comet_mock.ExperimentConfig(), ) - _ = logger.experiment - comet_existing.assert_called_once_with( - api_key="key", workspace="dummy-test", project_name="general", previous_experiment="test" + + # Test online given + comet_start.reset_mock() + _logger = CometLogger(save_dir="test", api_key="key", workspace="dummy-test", project_name="general", online=True) + comet_start.assert_called_once_with( + api_key="key", + workspace="dummy-test", + project="general", + experiment_key=None, + mode=None, + online=True, + experiment_config=comet_mock.ExperimentConfig(), ) - comet_existing().set_name.assert_called_once_with("experiment") - # API experiment - api = comet_mock.api.API - CometLogger(api_key="key", workspace="dummy-test", project_name="general", rest_api_key="rest") - api.assert_called_once_with("rest") + # Test experiment_key given + comet_start.reset_mock() + _logger = CometLogger( + experiment_key="test_key", + api_key="key", + project="general", + ) + comet_start.assert_called_once_with( + api_key="key", + workspace=None, + project="general", + experiment_key="test_key", + mode=None, + online=None, + experiment_config=comet_mock.ExperimentConfig(), + ) @mock.patch.dict(os.environ, {}) -def test_comet_experiment_resets_if_not_alive(comet_mock): - """Test that the CometLogger creates a new experiment if the old one is not alive anymore.""" +def test_comet_experiment_is_still_alive_after_training_complete(comet_mock): + """Test that the CometLogger will not end an experiment after training is complete.""" + logger = CometLogger() - assert logger._experiment is None - alive_experiment = Mock(alive=True) - logger._experiment = alive_experiment - assert logger.experiment is alive_experiment + assert logger.experiment is not None - unalive_experiment = Mock(alive=False) - logger._experiment = unalive_experiment - assert logger.experiment is not unalive_experiment + logger._experiment = Mock() + logger.finalize("ended") + # Assert that data was saved to comet.com + logger._experiment.flush.assert_called_once() -@mock.patch.dict(os.environ, {}) -def test_comet_logger_no_api_key_given(comet_mock): - """Test that CometLogger fails to initialize if both api key and save_dir are missing.""" - with pytest.raises(MisconfigurationException, match="requires either api_key or save_dir"): - comet_mock.config.get_api_key.return_value = None - CometLogger(workspace="dummy-test", project_name="general") + # Assert that was not ended + logger._experiment.end.assert_not_called() @mock.patch.dict(os.environ, {}) def test_comet_logger_experiment_name(comet_mock): """Test that Comet Logger experiment name works correctly.""" - api_key = "key" - experiment_name = "My Name" + api_key = "api_key" + experiment_name = "My Experiment Name" + + comet_start = comet_mock.start - # Test api_key given - comet_experiment = comet_mock.Experiment + # here we use old style arg "experiment_name" (new one is "name") logger = CometLogger(api_key=api_key, experiment_name=experiment_name) - assert logger._experiment is None + comet_start.assert_called_once_with( + api_key=api_key, + workspace=None, + project=None, + experiment_key=None, + mode=None, + online=None, + experiment_config=comet_mock.ExperimentConfig(), + ) + # check that we saved "experiment name" in kwargs as new "name" arg + assert logger._kwargs["name"] == experiment_name + assert "experiment_name" not in logger._kwargs - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key=api_key, project_name=None) - comet_experiment().set_name.assert_called_once_with(experiment_name) + # check that "experiment name" was passed to experiment config correctly + assert call(experiment_name=experiment_name) not in comet_mock.ExperimentConfig.call_args_list + assert call(name=experiment_name) in comet_mock.ExperimentConfig.call_args_list @mock.patch.dict(os.environ, {}) -def test_comet_logger_manual_experiment_key(comet_mock): - """Test that Comet Logger respects manually set COMET_EXPERIMENT_KEY.""" +def test_comet_version(comet_mock): + """Test that CometLogger.version returns an Experiment key.""" api_key = "key" - experiment_key = "96346da91469407a85641afe5766b554" - - instantiation_environ = {} - - def save_os_environ(*args, **kwargs): - nonlocal instantiation_environ - instantiation_environ = os.environ.copy() - - return DEFAULT - - comet_experiment = comet_mock.Experiment - comet_experiment.side_effect = save_os_environ - - # Test api_key given - with patch.dict(os.environ, {"COMET_EXPERIMENT_KEY": experiment_key}): - logger = CometLogger(api_key=api_key) - assert logger.version == experiment_key - assert logger._experiment is None + experiment_name = "My Name" - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key=api_key, project_name=None) + logger = CometLogger(api_key=api_key, name=experiment_name) + assert logger._experiment is not None + _ = logger.version - assert instantiation_environ["COMET_EXPERIMENT_KEY"] == experiment_key + logger._experiment.get_key.assert_called() @mock.patch.dict(os.environ, {}) -def test_comet_logger_dirs_creation(comet_mock, tmp_path, monkeypatch): - """Test that the logger creates the folders and files in the right place.""" +def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch): + """Test that CometLogger removes the epoch key from the metrics dict and passes it as argument.""" _patch_comet_atexit(monkeypatch) - comet_experiment = comet_mock.OfflineExperiment - - comet_mock.config.get_api_key.return_value = None - comet_mock.generate_guid = Mock() - comet_mock.generate_guid.return_value = "4321" - logger = CometLogger(project_name="test", save_dir=str(tmp_path)) - assert not os.listdir(tmp_path) - assert logger.mode == "offline" - assert logger.save_dir == str(tmp_path) - assert logger.name == "test" - assert logger.version == "4321" - - _ = logger.experiment - comet_experiment.assert_called_once_with(offline_directory=str(tmp_path), project_name="test") - - # mock return values of experiment - logger.experiment.id = "1" - logger.experiment.project_name = "test" - - model = BoringModel() - trainer = Trainer( - default_root_dir=tmp_path, logger=logger, max_epochs=1, limit_train_batches=3, limit_val_batches=3 + logger.log_metrics({"test": 1, "epoch": 1}, step=123) + logger.experiment.__internal_api__log_metrics__.assert_called_once_with( + {"test": 1}, + epoch=1, + step=123, + prefix=logger._prefix, + framework="pytorch-lightning", ) - assert trainer.log_dir == logger.save_dir - trainer.fit(model) - - assert trainer.checkpoint_callback.dirpath == str(tmp_path / "test" / "1" / "checkpoints") - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=3.ckpt"} - assert trainer.log_dir == logger.save_dir @mock.patch.dict(os.environ, {}) -def test_comet_name_default(comet_mock): - """Test that CometLogger.name don't create an Experiment and returns a default value.""" - api_key = "key" - logger = CometLogger(api_key=api_key) - assert logger._experiment is None - assert logger.name == "comet-default" - assert logger._experiment is None - +def test_comet_log_hyperparams(comet_mock, tmp_path, monkeypatch): + """Test that CometLogger.log_hyperparams calls internal API method.""" + _patch_comet_atexit(monkeypatch) -@mock.patch.dict(os.environ, {}) -def test_comet_name_project_name(comet_mock): - """Test that CometLogger.name does not create an Experiment and returns project name if passed.""" - api_key = "key" - project_name = "My Project Name" - logger = CometLogger(api_key=api_key, project_name=project_name) - assert logger._experiment is None - assert logger.name == project_name - assert logger._experiment is None + logger = CometLogger(project_name="test") + hyperparams = { + "batch_size": 256, + "config": { + "SLURM Job ID": "22334455", + "RGB slurm jobID": "12345678", + "autoencoder_model": False, + }, + } + logger.log_hyperparams(hyperparams) + + logger.experiment.__internal_api__log_parameters__.assert_called_once_with( + parameters=hyperparams, + framework=FRAMEWORK_NAME, + flatten_nested=True, + source="manual", + ) @mock.patch.dict(os.environ, {}) -def test_comet_version_without_experiment(comet_mock): - """Test that CometLogger.version does not create an Experiment.""" - api_key = "key" - experiment_name = "My Name" - comet_mock.generate_guid = Mock() - comet_mock.generate_guid.return_value = "1234" - - logger = CometLogger(api_key=api_key, experiment_name=experiment_name) - assert logger._experiment is None - - first_version = logger.version - assert first_version is not None - assert logger.version == first_version - assert logger._experiment is None - - _ = logger.experiment - - logger.reset_experiment() +def test_comet_log_graph(comet_mock, tmp_path, monkeypatch): + """Test that CometLogger.log_hyperparams calls internal API method.""" + _patch_comet_atexit(monkeypatch) - second_version = logger.version == "1234" - assert second_version is not None - assert second_version != first_version + logger = CometLogger(project_name="test") + model = Mock() + logger.log_graph(model=model) -@mock.patch.dict(os.environ, {}) -def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch): - """Test that CometLogger removes the epoch key from the metrics dict and passes it as argument.""" - _patch_comet_atexit(monkeypatch) - logger = CometLogger(project_name="test", save_dir=str(tmp_path)) - logger.log_metrics({"test": 1, "epoch": 1}, step=123) - logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123) + logger.experiment.__internal_api__set_model_graph__.assert_called_once_with( + graph=model, + framework="pytorch-lightning", + ) @mock.patch.dict(os.environ, {}) diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py index 7b384890f6148..dcdd504fd4660 100644 --- a/tests/tests_pytorch/loggers/test_logger.py +++ b/tests/tests_pytorch/loggers/test_logger.py @@ -13,15 +13,13 @@ # limitations under the License. import pickle from argparse import Namespace -from contextlib import nullcontext from copy import deepcopy -from typing import Any, Dict, Optional +from typing import Any, Optional from unittest.mock import patch import numpy as np import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.fabric.utilities.logger import _convert_params, _sanitize_params from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel @@ -124,8 +122,7 @@ def test_multiple_loggers_pickle(tmp_path): trainer = Trainer(logger=[logger1, logger2]) pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): - trainer2 = pickle.loads(pkl_bytes) + trainer2 = pickle.loads(pkl_bytes) for logger in trainer2.loggers: logger.log_metrics({"acc": 1.0}, 0) @@ -255,12 +252,12 @@ def __init__(self, param_one, param_two): @patch("lightning.pytorch.loggers.tensorboard.TensorBoardLogger.log_hyperparams") def test_log_hyperparams_key_collision(_, tmp_path): class TestModel(BoringModel): - def __init__(self, hparams: Dict[str, Any]) -> None: + def __init__(self, hparams: dict[str, Any]) -> None: super().__init__() self.save_hyperparameters(hparams) class TestDataModule(BoringDataModule): - def __init__(self, hparams: Dict[str, Any]) -> None: + def __init__(self, hparams: dict[str, Any]) -> None: super().__init__() self.save_hyperparameters(hparams) diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index e9195f628348b..a8e70bfb6589d 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -13,13 +13,11 @@ # limitations under the License. import os import pickle -from contextlib import nullcontext from pathlib import Path from unittest import mock import pytest import yaml -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.cli import LightningCLI @@ -162,8 +160,7 @@ def name(self): assert trainer.logger.experiment, "missing experiment" assert trainer.log_dir == logger.save_dir pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): - trainer2 = pickle.loads(pkl_bytes) + trainer2 = pickle.loads(pkl_bytes) assert os.environ["WANDB_MODE"] == "dryrun" assert trainer2.logger.__class__.__name__ == WandbLogger.__name__ diff --git a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py index 0ea6290586f55..2fb04d0d9d8d1 100644 --- a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterator, Mapping from contextlib import nullcontext -from typing import Dict, Generic, Iterator, Mapping, TypeVar +from typing import Generic, TypeVar import pytest import torch @@ -49,8 +50,8 @@ def test_closure_result_apply_accumulation(): class OutputMapping(Generic[T], Mapping[str, T]): - def __init__(self, d: Dict[str, T]) -> None: - self.d: Dict[str, T] = d + def __init__(self, d: dict[str, T]) -> None: + self.d: dict[str, T] = d def __iter__(self) -> Iterator[str]: return iter(self.d) diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py index 763a6ded14447..75b25e3d98fd8 100644 --- a/tests/tests_pytorch/loops/test_fetchers.py +++ b/tests/tests_pytorch/loops/test_fetchers.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import Counter -from typing import Any, Iterator +from collections.abc import Iterator +from typing import Any import pytest import torch diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index ff317cd2e18ba..1820ca3568173 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from collections.abc import Iterator from copy import deepcopy from dataclasses import dataclass -from typing import Dict, Iterator +from typing import Any from unittest.mock import ANY, Mock import pytest @@ -25,6 +26,7 @@ from lightning.pytorch.loops import _Loop from lightning.pytorch.loops.progress import _BaseProgress from lightning.pytorch.utilities import CombinedLoader +from lightning.pytorch.utilities.types import STEP_OUTPUT from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter from tests_pytorch.helpers.runif import RunIf @@ -86,10 +88,10 @@ def advance(self) -> None: self.outputs.append(value) - def state_dict(self) -> Dict: + def state_dict(self) -> dict: return {"iteration_count": self.iteration_count, "outputs": self.outputs} - def load_state_dict(self, state_dict: Dict) -> None: + def load_state_dict(self, state_dict: dict) -> None: self.iteration_count = state_dict["iteration_count"] self.outputs = state_dict["outputs"] @@ -139,10 +141,10 @@ def advance(self) -> None: return loop.run() - def on_save_checkpoint(self) -> Dict: + def on_save_checkpoint(self) -> dict: return {"a": self.a} - def on_load_checkpoint(self, state_dict: Dict) -> None: + def on_load_checkpoint(self, state_dict: dict) -> None: self.a = state_dict["a"] trainer = Trainer() @@ -396,12 +398,13 @@ def training_step(self, batch, batch_idx): assert state_dict == checkpoint["loops"]["fit_loop"] trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) - # test resetting manually, we expect all `ready` counters to be reset to `completed` + # test resetting manually, we expect the `ready` counter for batch to be reset to `completed` + # but the `ready` counter for epoch to not be reset, since we are still mid epoch trainer.fit_loop.reset() trainer.fit_loop.epoch_loop.reset() epoch_progress = trainer.fit_loop.epoch_progress - assert epoch_progress.current.ready == stop_epoch + assert epoch_progress.current.ready == stop_epoch + 1 assert epoch_progress.current.completed == stop_epoch batch_progress = trainer.fit_loop.epoch_loop.batch_progress @@ -417,7 +420,7 @@ def training_step(self, batch, batch_idx): state_dict = trainer.fit_loop.state_dict() assert state_dict != checkpoint["loops"]["fit_loop"] assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1 - assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch + assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch + 1 def test_loop_state_on_complete_run(tmp_path): @@ -557,23 +560,38 @@ def test_fit_loop_reset(tmp_path): # we load exactly what was saved - no reset yet fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"]) - # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0 + + assert fit_loop.restarting + assert fit_loop.epoch_progress.total.ready == 1 + assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch + assert fit_loop.epoch_progress.current.ready == 1 + assert fit_loop.epoch_progress.current.completed == 0 + + assert epoch_loop.batch_progress.total.ready == 2 + assert epoch_loop.batch_progress.total.processed == 2 + assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end + assert epoch_loop.batch_progress.current.ready == 2 # currents get set to the completed value + assert epoch_loop.batch_progress.current.processed == 2 + assert epoch_loop.batch_progress.current.completed == 1 + fit_loop.reset() epoch_loop.reset() + # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0 assert fit_loop.restarting assert fit_loop.epoch_progress.total.ready == 1 assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch - assert fit_loop.epoch_progress.current.ready == 0 + assert fit_loop.epoch_progress.current.ready == 1 assert fit_loop.epoch_progress.current.completed == 0 + # however it should increment completed batch progress, since it was saved immediately prior assert epoch_loop.restarting assert epoch_loop.batch_progress.total.ready == 2 assert epoch_loop.batch_progress.total.processed == 2 - assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end - assert epoch_loop.batch_progress.current.ready == 1 # currents get set to the completed value - assert epoch_loop.batch_progress.current.processed == 1 - assert epoch_loop.batch_progress.current.completed == 1 + assert epoch_loop.batch_progress.total.completed == 2 + assert epoch_loop.batch_progress.current.ready == 2 + assert epoch_loop.batch_progress.current.processed == 2 + assert epoch_loop.batch_progress.current.completed == 2 assert optimizer_loop.restarting @@ -587,23 +605,326 @@ def test_fit_loop_reset(tmp_path): # we load exactly what was saved - no reset yet fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"]) + + assert fit_loop.restarting + assert fit_loop.epoch_progress.total.ready == 1 + assert fit_loop.epoch_progress.total.completed == 0 + assert fit_loop.epoch_progress.current.ready == 1 + assert fit_loop.epoch_progress.current.completed == 0 + # resetting from a end-of-epoch checkpoint SHOULD reset the current counters to 0 fit_loop.reset() epoch_loop.reset() + # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0 + # since we are restarting at the end of epoch, we need to see `completed` being updated after reset assert fit_loop.restarting assert fit_loop.epoch_progress.total.ready == 1 - assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes - assert fit_loop.epoch_progress.current.ready == 0 - assert fit_loop.epoch_progress.current.completed == 0 + assert fit_loop.epoch_progress.total.completed == 1 + assert fit_loop.epoch_progress.current.ready == 1 + assert fit_loop.epoch_progress.current.completed == 1 + # however it should increment completed batch progress, since it was saved immediately prior assert epoch_loop.restarting assert epoch_loop.batch_progress.total.ready == 4 assert epoch_loop.batch_progress.total.processed == 4 - assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end - assert epoch_loop.batch_progress.current.ready == 3 # currents get set to the completed value - assert epoch_loop.batch_progress.current.processed == 3 - assert epoch_loop.batch_progress.current.completed == 3 + assert epoch_loop.batch_progress.total.completed == 4 + assert epoch_loop.batch_progress.current.ready == 0 + assert epoch_loop.batch_progress.current.processed == 0 + assert epoch_loop.batch_progress.current.completed == 0 + + +def compare_state_dicts(dict1, dict2): + def compare_leaves(d1, d2): + result = {} + all_keys = set(d1.keys()).union(d2.keys()) + + for key in all_keys: + val1 = d1.get(key, None) + val2 = d2.get(key, None) + + if isinstance(val1, dict) and isinstance(val2, dict): + res = compare_leaves(val1, val2) + if res: + result[key] = res + elif isinstance(val1, dict) or isinstance(val2, dict): + raise ValueError("dicts have different leaves") + elif isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor): + diff = torch.norm(val1 - val2) + if diff > 1e-8: + result[key] = f"{diff} > 1e-8" + elif isinstance(val1, float) and isinstance(val2, float): + if abs(val1 - val2) > 1e-8: + result[key] = f"{val1} != {val2}" + elif val1 != val2: + result[key] = f"{val1} != {val2}" + return result + + return compare_leaves(dict1, dict2) + + +class RangeDataset(torch.utils.data.Dataset): + def __init__(self, size: int, length: int): + self.len = length + data = torch.arange(0, size) / size + self.data = data.unsqueeze(0).repeat(length, 1) + + def __getitem__(self, index: int) -> torch.Tensor: + return self.data[index] + + def __len__(self) -> int: + return self.len + + +class PredictableBoringModel(BoringModel): + def __init__(self) -> None: + super().__init__() + self.last_loss = float("inf") + + def train_dataloader(self) -> DataLoader: + return DataLoader(RangeDataset(32, 64)) + + def val_dataloader(self) -> DataLoader: + return DataLoader(RangeDataset(32, 64)) + + def test_dataloader(self) -> DataLoader: + return DataLoader(RangeDataset(32, 64)) + + def predict_dataloader(self) -> DataLoader: + return DataLoader(RangeDataset(32, 64)) + + def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT: + loss = self.step(batch) + self.last_loss = loss + return {"loss": loss} + + +def test_restart_parity(tmp_path): + model = PredictableBoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + every_n_train_steps=2, + save_top_k=-1, + ) + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=4, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model) + loss = model.last_loss + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=4, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt")) + loss_v1 = model.last_loss + + assert abs(loss - loss_v1) < 1e-8 + + end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True) + end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=0-step=4-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] + assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} + + mid_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=6.ckpt"), weights_only=True) + mid_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=6-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(mid_epoch_ckpt["loops"], mid_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(mid_epoch_ckpt["lr_schedulers"][0], mid_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert mid_epoch_ckpt["epoch"] == mid_epoch_ckpt_v1["epoch"] + assert mid_epoch_ckpt["global_step"] == mid_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(mid_epoch_ckpt["state_dict"], mid_epoch_ckpt_v1["state_dict"]) == {} + + end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=8.ckpt"), weights_only=True) + end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=8-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] + assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} + + +def test_restart_with_val_parity(tmp_path): + model = PredictableBoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + every_n_train_steps=2, + save_top_k=-1, + ) + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=4, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=4, + val_check_interval=2, + ) + trainer.fit(model) + loss = model.last_loss + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=4, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=4, + val_check_interval=2, + ) + trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt")) + loss_v1 = model.last_loss + + assert abs(loss - loss_v1) < 1e-8 + + end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True) + end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=0-step=4-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] + assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} + + mid_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=6.ckpt"), weights_only=True) + mid_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=6-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(mid_epoch_ckpt["loops"], mid_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(mid_epoch_ckpt["lr_schedulers"][0], mid_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert mid_epoch_ckpt["epoch"] == mid_epoch_ckpt_v1["epoch"] + assert mid_epoch_ckpt["global_step"] == mid_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(mid_epoch_ckpt["state_dict"], mid_epoch_ckpt_v1["state_dict"]) == {} + + end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=8.ckpt"), weights_only=True) + end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=8-v1.ckpt"), weights_only=True) + + assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {} + assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {} + assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"] + assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"] + assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {} + + +def test_restart_from_last_parity(tmp_path): + model = PredictableBoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + save_last=True, + save_top_k=-1, + ) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model) + + last_ckpt_1 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=2, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(model, ckpt_path=str(tmp_path / "last.ckpt")) + + last_ckpt_2 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True) + + assert compare_state_dicts(last_ckpt_1["loops"], last_ckpt_2["loops"]) == {} + + +def test_restart_from_last_with_val_parity(tmp_path): + model = PredictableBoringModel() + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + save_last=True, + save_top_k=-1, + ) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=2, + val_check_interval=2, + ) + trainer.fit(model) + + last_ckpt_1 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=2, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=2, + val_check_interval=2, + ) + trainer.fit(model) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_model_summary=False, + enable_progress_bar=False, + limit_val_batches=2, + val_check_interval=2, + ) + trainer.fit(model, ckpt_path=str(tmp_path / "last.ckpt")) + + last_ckpt_2 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True) + + assert compare_state_dicts(last_ckpt_1["loops"], last_ckpt_2["loops"]) == {} @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 5a175e181dd9e..1a8aeb4b297a9 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -61,7 +61,7 @@ def on_before_zero_grad(self, optimizer): model = CurrentTestModel() - trainer = Trainer(default_root_dir=tmp_path, max_steps=max_steps, max_epochs=2) + trainer = Trainer(devices=1, default_root_dir=tmp_path, max_steps=max_steps, max_epochs=2) assert model.on_before_zero_grad_called == 0 trainer.fit(model) assert max_steps == model.on_before_zero_grad_called @@ -406,7 +406,7 @@ def prepare_data(self): ... @pytest.mark.parametrize( "kwargs", [ - {}, + {"devices": 1}, # these precision plugins modify the optimization flow, so testing them explicitly pytest.param({"accelerator": "gpu", "devices": 1, "precision": "16-mixed"}, marks=RunIf(min_cuda_gpus=1)), pytest.param( @@ -528,6 +528,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path): # initial training to get a checkpoint model = BoringModel() trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2, @@ -543,6 +544,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path): callback = HookedCallback(called) # already performed 1 step, resume and do 2 more trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_epochs=2, limit_train_batches=2, @@ -605,6 +607,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path): # initial training to get a checkpoint model = BoringModel() trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_steps=1, limit_val_batches=0, @@ -624,6 +627,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path): train_batches = 2 steps_after_reload = 1 + train_batches trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_steps=steps_after_reload, limit_val_batches=0, @@ -660,8 +664,6 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path): {"name": "train_dataloader"}, {"name": "Callback.on_train_start", "args": (trainer, model)}, {"name": "on_train_start"}, - {"name": "Callback.on_train_epoch_start", "args": (trainer, model)}, - {"name": "on_train_epoch_start"}, *model._train_batch(trainer, model, steps_after_reload, trainer.strategy.root_device, current_batch=1), {"name": "Callback.on_train_epoch_end", "args": (trainer, model)}, {"name": "on_train_epoch_end"}, # before ModelCheckpoint because it's a "monitoring callback" @@ -692,6 +694,7 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat assert is_overridden(f"on_{noun}_model_train", model) == override_on_x_model_train callback = HookedCallback(called) trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_epochs=1, limit_val_batches=batches, @@ -733,7 +736,11 @@ def test_trainer_model_hook_system_predict(tmp_path): callback = HookedCallback(called) batches = 2 trainer = Trainer( - default_root_dir=tmp_path, limit_predict_batches=batches, enable_progress_bar=False, callbacks=[callback] + devices=1, + default_root_dir=tmp_path, + limit_predict_batches=batches, + enable_progress_bar=False, + callbacks=[callback], ) trainer.predict(model) expected = [ @@ -799,7 +806,7 @@ def predict_dataloader(self): model = CustomBoringModel() - trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=5) + trainer = Trainer(devices=1, default_root_dir=tmp_path, fast_dev_run=5) trainer.fit(model) trainer.test(model) @@ -814,6 +821,7 @@ def test_trainer_datamodule_hook_system(tmp_path): model = BoringModel() batches = 2 trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_epochs=1, limit_train_batches=batches, @@ -889,7 +897,7 @@ class CustomHookedModel(HookedModel): assert is_overridden("configure_model", model) == override_configure_model datamodule = CustomHookedDataModule(ldm_called) - trainer = Trainer() + trainer = Trainer(devices=1) trainer.strategy.connect(model) trainer._data_connector.attach_data(model, datamodule=datamodule) ckpt_path = str(tmp_path / "file.ckpt") @@ -962,6 +970,7 @@ def predict_step(self, *args, **kwargs): model = MixedTrainModeModule() trainer = Trainer( + devices=1, default_root_dir=tmp_path, max_epochs=1, val_check_interval=1, diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index fe7e3fbbab357..64f70b176a971 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -15,8 +15,9 @@ import logging as log import os import pickle +from collections.abc import Mapping from copy import deepcopy -from typing import Generic, Mapping, TypeVar +from typing import Generic, TypeVar import cloudpickle import pytest diff --git a/tests/tests_pytorch/overrides/test_distributed.py b/tests/tests_pytorch/overrides/test_distributed.py index 29eb6d6d6d511..3e2fba54bcd03 100644 --- a/tests/tests_pytorch/overrides/test_distributed.py +++ b/tests/tests_pytorch/overrides/test_distributed.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable +from collections.abc import Iterable import pytest import torch diff --git a/tests/tests_pytorch/plugins/precision/test_fsdp.py b/tests/tests_pytorch/plugins/precision/test_fsdp.py index 8b595c2c74a32..3ad3af1f1b56b 100644 --- a/tests/tests_pytorch/plugins/precision/test_fsdp.py +++ b/tests/tests_pytorch/plugins/precision/test_fsdp.py @@ -40,6 +40,24 @@ def test_fsdp_precision_config(precision, expected): assert config.reduce_dtype == expected[2] +@pytest.mark.parametrize( + ("precision", "expected_dtype"), + [ + ("32-true", torch.float32), + ("bf16-mixed", torch.float32), + ("16-mixed", torch.float32), + ("bf16-true", torch.bfloat16), + ("16-true", torch.float16), + ], +) +def test_convert_module(precision, expected_dtype): + precision = FSDPPrecision(precision=precision) + module = torch.nn.Linear(2, 2) + assert module.weight.dtype == module.bias.dtype == torch.float32 + module = precision.convert_module(module) + assert module.weight.dtype == module.bias.dtype == expected_dtype + + def test_fsdp_precision_default_scaler(): from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index 185a767d9e8c9..58baa47e7a620 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. import os from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Optional from unittest.mock import MagicMock, Mock import torch @@ -27,10 +27,10 @@ class CustomCheckpointIO(CheckpointIO): - def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: torch.save(checkpoint, path) - def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]: + def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> dict[str, Any]: return torch.load(path, weights_only=True) def remove_checkpoint(self, path: _PATH) -> None: diff --git a/tests/tests_pytorch/run_standalone_tasks.sh b/tests/tests_pytorch/run_standalone_tasks.sh index eb2f7f6b22d86..48bc920adab0a 100644 --- a/tests/tests_pytorch/run_standalone_tasks.sh +++ b/tests/tests_pytorch/run_standalone_tasks.sh @@ -21,7 +21,13 @@ export PL_RUN_STANDALONE_TESTS=1 # test that a user can manually launch individual processes echo "Running manual ddp launch test" export PYTHONPATH="${PYTHONPATH}:$(pwd)" -args="fit --trainer.accelerator gpu --trainer.devices 2 --trainer.strategy ddp --trainer.max_epochs=1 --trainer.limit_train_batches=1 --trainer.limit_val_batches=1 --trainer.limit_test_batches=1" +args="fit --trainer.accelerator gpu \ +--trainer.devices 2 \ +--trainer.strategy ddp \ +--trainer.max_epochs=1 \ +--trainer.limit_train_batches=1 \ +--trainer.limit_val_batches=1 \ +--trainer.limit_test_batches=1" MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python strategies/scripts/cli_script.py ${args} & MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=0 python strategies/scripts/cli_script.py ${args} diff --git a/tests/tests_pytorch/serve/test_servable_module_validator.py b/tests/tests_pytorch/serve/test_servable_module_validator.py index 7c883c419ea82..ec4dd8825c8ea 100644 --- a/tests/tests_pytorch/serve/test_servable_module_validator.py +++ b/tests/tests_pytorch/serve/test_servable_module_validator.py @@ -1,5 +1,3 @@ -from typing import Dict - import pytest import torch from lightning.pytorch import Trainer @@ -21,7 +19,7 @@ def serialize(x): return {"x": deserialize}, {"output": serialize} - def serve_step(self, x: Tensor) -> Dict[str, Tensor]: + def serve_step(self, x: Tensor) -> dict[str, Tensor]: assert torch.equal(x, torch.arange(32, dtype=torch.float)) return {"output": torch.tensor([0, 1])} diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index b0462c0105a9f..394d827058987 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py @@ -209,10 +209,13 @@ def test_memory_sharing_disabled(tmp_path): def test_check_for_missing_main_guard(): launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn") - with mock.patch( - "lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process", - return_value=Mock(_inheriting=True), # pretend that main is importing itself - ), pytest.raises(RuntimeError, match="requires that your script guards the main"): + with ( + mock.patch( + "lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process", + return_value=Mock(_inheriting=True), # pretend that main is importing itself + ), + pytest.raises(RuntimeError, match="requires that your script guards the main"), + ): launcher.launch(function=Mock()) diff --git a/tests/tests_pytorch/strategies/test_custom_strategy.py b/tests/tests_pytorch/strategies/test_custom_strategy.py index 7f7d018f634e1..347dacbd9a811 100644 --- a/tests/tests_pytorch/strategies/test_custom_strategy.py +++ b/tests/tests_pytorch/strategies/test_custom_strategy.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any import pytest import torch diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py index be9428ff7533c..73697ea131545 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed.py +++ b/tests/tests_pytorch/strategies/test_deepspeed.py @@ -15,7 +15,7 @@ import json import os from re import escape -from typing import Any, Dict +from typing import Any from unittest import mock from unittest.mock import ANY, Mock @@ -48,7 +48,7 @@ def configure_model(self) -> None: if self.layer is None: self.layer = torch.nn.Linear(32, 2) - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: self.configure_model() @@ -73,7 +73,7 @@ def configure_model(self) -> None: if self.layer is None: self.layer = torch.nn.Linear(32, 2) - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: self.configure_model() @property @@ -623,7 +623,7 @@ def configure_optimizers(self): lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99) return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}] - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: if not hasattr(self, "model"): self.configure_model() diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index aec01b83e956a..2aee68f7ae733 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -444,9 +444,12 @@ def __init__(self): strategy._parallel_devices = [torch.device("cuda", 0)] strategy._lightning_module = model strategy._process_group = Mock() - with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch( - "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" - ) as apply_mock: + with ( + mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), + mock.patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" + ) as apply_mock, + ): wrapped = strategy._setup_model(model) apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs) diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index 57d273917573a..9dcbcc802834b 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -78,10 +78,26 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh): return model +def _parallelize_with_compile(parallelize): + def fn(model, device_mesh): + model = parallelize(model, device_mesh) + return torch.compile(model) + + return fn + + +@pytest.fixture() +def distributed(): + yield + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + class TemplateModel(LightningModule): - def __init__(self): + def __init__(self, compile=False): super().__init__() self.model = FeedForward() + self._compile = compile def training_step(self, batch): output = self.model(batch) @@ -98,21 +114,30 @@ def configure_optimizers(self): class FSDP2Model(TemplateModel): def configure_model(self): - _parallelize_feed_forward_fsdp2(self.model, device_mesh=self.device_mesh) + parallelize = _parallelize_feed_forward_fsdp2_tp + if self._compile: + parallelize = _parallelize_with_compile(parallelize) + parallelize(self.model, device_mesh=self.device_mesh) class TensorParallelModel(TemplateModel): def configure_model(self): - _parallelize_feed_forward_tp(self.model, device_mesh=self.device_mesh) + parallelize = _parallelize_feed_forward_tp + if self._compile: + parallelize = _parallelize_with_compile(parallelize) + parallelize(self.model, device_mesh=self.device_mesh) class FSDP2TensorParallelModel(TemplateModel): def configure_model(self): - _parallelize_feed_forward_fsdp2_tp(self.model, device_mesh=self.device_mesh) + parallelize = _parallelize_feed_forward_fsdp2_tp + if self._compile: + parallelize = _parallelize_with_compile(parallelize) + parallelize(self.model, device_mesh=self.device_mesh) @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_setup_device_mesh(): +def test_setup_device_mesh(distributed): from torch.distributed.device_mesh import DeviceMesh for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)): @@ -169,7 +194,11 @@ def configure_model(self): @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2) -def test_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor class Model(TensorParallelModel): @@ -204,13 +233,17 @@ def training_step(self, batch): seed_everything(0) with trainer.init_module(empty_init=True): - model = Model() + model = Model(compile=compile) trainer.fit(model) @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) -def test_fsdp2_tensor_parallel(): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_fsdp2_tensor_parallel(distributed, compile): from torch.distributed._tensor import DTensor class Model(FSDP2TensorParallelModel): @@ -261,13 +294,13 @@ def training_step(self, batch): seed_everything(0) with trainer.init_module(empty_init=True): - model = Model() + model = Model(compile=compile) trainer.fit(model) @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_modules_without_parameters(tmp_path): +def test_modules_without_parameters(distributed, tmp_path): """Test that TorchMetrics get moved to the device despite not having any parameters.""" class MetricsModel(TensorParallelModel): @@ -306,7 +339,11 @@ def training_step(self, batch): pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)), ], ) -def test_module_init_context(precision, expected_dtype, tmp_path): +@pytest.mark.parametrize( + "compile", + [True, False], +) +def test_module_init_context(distributed, compile, precision, expected_dtype, tmp_path): """Test that the module under the init-context gets moved to the right device and dtype.""" class Model(FSDP2Model): @@ -329,7 +366,7 @@ def _run_setup_assertions(empty_init, expected_device): logger=False, ) with trainer.init_module(empty_init=empty_init): - model = Model() + model = Model(compile=compile) # The model is on the CPU/meta-device until after `ModelParallelStrategy.setup()` assert model.model.w1.weight.device == expected_device @@ -345,7 +382,7 @@ def _run_setup_assertions(empty_init, expected_device): @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize("save_distributed_checkpoint", [True, False]) -def test_strategy_state_dict(tmp_path, save_distributed_checkpoint): +def test_strategy_state_dict(distributed, tmp_path, save_distributed_checkpoint): """Test that the strategy returns the correct state dict of the LightningModule.""" model = FSDP2Model() correct_state_dict = model.state_dict() # State dict before wrapping @@ -378,7 +415,7 @@ def test_strategy_state_dict(tmp_path, save_distributed_checkpoint): @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) -def test_load_full_state_checkpoint_into_regular_model(tmp_path): +def test_load_full_state_checkpoint_into_regular_model(distributed, tmp_path): """Test that a full-state checkpoint saved from a distributed model can be loaded back into a regular model.""" # Save a regular full-state checkpoint from a distributed model @@ -420,7 +457,7 @@ def test_load_full_state_checkpoint_into_regular_model(tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) -def test_load_standard_checkpoint_into_distributed_model(tmp_path): +def test_load_standard_checkpoint_into_distributed_model(distributed, tmp_path): """Test that a regular checkpoint (weights and optimizer states) can be loaded into a distributed model.""" # Save a regular DDP checkpoint @@ -461,7 +498,7 @@ def test_load_standard_checkpoint_into_distributed_model(tmp_path): @pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) -def test_save_load_sharded_state_dict(tmp_path): +def test_save_load_sharded_state_dict(distributed, tmp_path): """Test saving and loading with the distributed state dict format.""" class CheckpointModel(FSDP2Model): diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 56b58d4d157a1..7341d33292a58 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -20,7 +20,7 @@ from contextlib import ExitStack, contextmanager, redirect_stdout from io import StringIO from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Callable, Optional, Union from unittest import mock from unittest.mock import ANY @@ -127,7 +127,7 @@ def _model_builder(model_param: int) -> Model: def _trainer_builder( - limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[List[Callback], Callback]] = None + limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[list[Callback], Callback]] = None ) -> Trainer: return Trainer(limit_train_batches=limit_train_batches, fast_dev_run=fast_dev_run, callbacks=callbacks) @@ -409,8 +409,9 @@ def test_lightning_cli_config_and_subclass_mode(cleandir): with open(config_path, "w") as f: f.write(yaml.dump(input_config)) - with mock.patch("sys.argv", ["any.py", "--config", config_path]), mock_subclasses( - LightningDataModule, DataDirDataModule + with ( + mock.patch("sys.argv", ["any.py", "--config", config_path]), + mock_subclasses(LightningDataModule, DataDirDataModule), ): cli = LightningCLI( BoringModel, @@ -461,9 +462,12 @@ def test_lightning_cli_help(): cli_args = ["any.py", "fit", "--data.help=DataDirDataModule"] out = StringIO() - with mock.patch("sys.argv", cli_args), redirect_stdout(out), mock_subclasses( - LightningDataModule, DataDirDataModule - ), pytest.raises(SystemExit): + with ( + mock.patch("sys.argv", cli_args), + redirect_stdout(out), + mock_subclasses(LightningDataModule, DataDirDataModule), + pytest.raises(SystemExit), + ): any_model_any_data_cli() assert ("--data.data_dir" in out.getvalue()) or ("--data.init_args.data_dir" in out.getvalue()) @@ -474,8 +478,8 @@ def test_lightning_cli_print_config(): "any.py", "predict", "--seed_everything=1234", - "--model=lightning.pytorch.demos.boring_classes.BoringModel", - "--data=lightning.pytorch.demos.boring_classes.BoringDataModule", + "--model=lightning.pytorch.demos.BoringModel", + "--data=lightning.pytorch.demos.BoringDataModule", "--print_config", ] out = StringIO() @@ -488,8 +492,8 @@ def test_lightning_cli_print_config(): outval = yaml.safe_load(text) assert outval["seed_everything"] == 1234 - assert outval["model"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringModel" - assert outval["data"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringDataModule" + assert outval["model"]["class_path"] == "lightning.pytorch.demos.BoringModel" + assert outval["data"]["class_path"] == "lightning.pytorch.demos.BoringDataModule" assert outval["ckpt_path"] is None @@ -522,7 +526,7 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai @pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason=str(_TORCHVISION_AVAILABLE)) def test_lightning_cli_torch_modules(cleandir): class TestModule(BoringModel): - def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None): + def __init__(self, activation: torch.nn.Module = None, transform: Optional[list[torch.nn.Module]] = None): super().__init__() self.activation = activation self.transform = transform @@ -609,8 +613,9 @@ def on_fit_start(self): def test_cli_distributed_save_config_callback(cleandir, logger, strategy): from torch.multiprocessing import ProcessRaisedException - with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises( - (MisconfigurationException, ProcessRaisedException), match=r"Error on fit start" + with ( + mock.patch("sys.argv", ["any.py", "fit"]), + pytest.raises((MisconfigurationException, ProcessRaisedException), match=r"Error on fit start"), ): LightningCLI( EarlyExitTestModel, @@ -710,12 +715,14 @@ def train_dataloader(self): ... from lightning.pytorch.trainer.configuration_validator import __verify_train_val_loop_configuration - with mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]), mock.patch( - "lightning.pytorch.Trainer._run_stage" - ) as run, mock.patch( - "lightning.pytorch.trainer.configuration_validator.__verify_train_val_loop_configuration", - wraps=__verify_train_val_loop_configuration, - ) as verify: + with ( + mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]), + mock.patch("lightning.pytorch.Trainer._run_stage") as run, + mock.patch( + "lightning.pytorch.trainer.configuration_validator.__verify_train_val_loop_configuration", + wraps=__verify_train_val_loop_configuration, + ) as verify, + ): cli = LightningCLI(BoringModel) run.assert_called_once() verify.assert_called_once_with(cli.trainer, cli.model) @@ -871,18 +878,27 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir): hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml" assert hparams_path.is_file() hparams = yaml.safe_load(hparams_path.read_text()) - expected = { - "_instantiator": "lightning.pytorch.cli.instantiate_module", - "optimizer": "torch.optim.Adam", - "scheduler": "torch.optim.lr_scheduler.ConstantLR", - "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}}, - } - assert hparams == expected + + expected_keys = ["_instantiator", "activation", "optimizer", "scheduler"] + expected_instantiator = "lightning.pytorch.cli.instantiate_module" + expected_activation = "torch.nn.LeakyReLU" + expected_optimizer = "torch.optim.Adam" + expected_scheduler = "torch.optim.lr_scheduler.ConstantLR" + + assert sorted(hparams.keys()) == expected_keys + assert hparams["_instantiator"] == expected_instantiator + assert hparams["activation"]["class_path"] == expected_activation + assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer + assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None) assert checkpoint_path.is_file() - ckpt = torch.load(checkpoint_path, weights_only=True) - assert ckpt["hyper_parameters"] == expected + hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"] + assert sorted(hparams.keys()) == expected_keys + assert hparams["_instantiator"] == expected_instantiator + assert hparams["activation"]["class_path"] == expected_activation + assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer + assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler model = TestModelSaveHparams.load_from_checkpoint(checkpoint_path) assert isinstance(model, TestModelSaveHparams) @@ -898,18 +914,23 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(c cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False, subclass_mode_model=True) cli.trainer.fit(cli.model) - expected = { - "_instantiator": "lightning.pytorch.cli.instantiate_module", - "_class_path": f"{__name__}.TestModelSaveHparams", - "optimizer": "torch.optim.Adam", - "scheduler": "torch.optim.lr_scheduler.ConstantLR", - "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}}, - } + expected_keys = ["_class_path", "_instantiator", "activation", "optimizer", "scheduler"] + expected_instantiator = "lightning.pytorch.cli.instantiate_module" + expected_class_path = f"{__name__}.TestModelSaveHparams" + expected_activation = "torch.nn.LeakyReLU" + expected_optimizer = "torch.optim.Adam" + expected_scheduler = "torch.optim.lr_scheduler.ConstantLR" checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None) assert checkpoint_path.is_file() - ckpt = torch.load(checkpoint_path, weights_only=True) - assert ckpt["hyper_parameters"] == expected + hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"] + + assert sorted(hparams.keys()) == expected_keys + assert hparams["_instantiator"] == expected_instantiator + assert hparams["_class_path"] == expected_class_path + assert hparams["activation"]["class_path"] == expected_activation + assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer + assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler model = LightningModule.load_from_checkpoint(checkpoint_path) assert isinstance(model, TestModelSaveHparams) @@ -952,6 +973,29 @@ def test_lightning_cli_save_hyperparameters_untyped_module(cleandir): assert model.kwargs == {"x": 1} +class TestDataSaveHparams(BoringDataModule): + def __init__(self, batch_size: int = 32, num_workers: int = 4): + super().__init__() + self.save_hyperparameters() + self.batch_size = batch_size + self.num_workers = num_workers + + +def test_lightning_cli_save_hyperparameters_merge(cleandir): + config = { + "model": { + "class_path": f"{__name__}.TestModelSaveHparams", + }, + "data": { + "class_path": f"{__name__}.TestDataSaveHparams", + }, + } + with mock.patch("sys.argv", ["any.py", "fit", f"--config={json.dumps(config)}", "--trainer.max_epochs=1"]): + cli = LightningCLI(auto_configure_optimizers=False) + assert set(cli.model.hparams) == {"optimizer", "scheduler", "activation", "_instantiator", "_class_path"} + assert set(cli.datamodule.hparams) == {"batch_size", "num_workers", "_instantiator", "_class_path"} + + @pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn]) def test_lightning_cli_trainer_fn(fn): class TestCLI(LightningCLI): @@ -1074,15 +1118,18 @@ def __init__(self, foo, bar=5): @_xfail_python_ge_3_11_9 def test_lightning_cli_model_short_arguments(): - with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch( - "lightning.pytorch.Trainer._fit_impl" - ) as run, mock_subclasses(LightningModule, BoringModel, TestModel): + with ( + mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), + mock.patch("lightning.pytorch.Trainer._fit_impl") as run, + mock_subclasses(LightningModule, BoringModel, TestModel), + ): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY) - with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), mock_subclasses( - LightningModule, BoringModel, TestModel + with ( + mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), + mock_subclasses(LightningModule, BoringModel, TestModel), ): cli = LightningCLI(run=False) assert isinstance(cli.model, TestModel) @@ -1100,15 +1147,18 @@ def __init__(self, foo, bar=5): @_xfail_python_ge_3_11_9 def test_lightning_cli_datamodule_short_arguments(): # with set model - with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch( - "lightning.pytorch.Trainer._fit_impl" - ) as run, mock_subclasses(LightningDataModule, BoringDataModule): + with ( + mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), + mock.patch("lightning.pytorch.Trainer._fit_impl") as run, + mock_subclasses(LightningDataModule, BoringDataModule), + ): cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.datamodule, BoringDataModule) run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY) - with mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), mock_subclasses( - LightningDataModule, MyDataModule + with ( + mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), + mock_subclasses(LightningDataModule, MyDataModule), ): cli = LightningCLI(BoringModel, run=False) assert isinstance(cli.datamodule, MyDataModule) @@ -1116,17 +1166,22 @@ def test_lightning_cli_datamodule_short_arguments(): assert cli.datamodule.bar == 5 # with configurable model - with mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), mock.patch( - "lightning.pytorch.Trainer._fit_impl" - ) as run, mock_subclasses(LightningModule, BoringModel), mock_subclasses(LightningDataModule, BoringDataModule): + with ( + mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), + mock.patch("lightning.pytorch.Trainer._fit_impl") as run, + mock_subclasses(LightningModule, BoringModel), + mock_subclasses(LightningDataModule, BoringDataModule), + ): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) assert isinstance(cli.datamodule, BoringDataModule) run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY) - with mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), mock_subclasses( - LightningModule, BoringModel - ), mock_subclasses(LightningDataModule, MyDataModule): + with ( + mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), + mock_subclasses(LightningModule, BoringModel), + mock_subclasses(LightningDataModule, MyDataModule), + ): cli = LightningCLI(run=False) assert isinstance(cli.model, BoringModel) assert isinstance(cli.datamodule, MyDataModule) @@ -1293,9 +1348,10 @@ def __init__(self, opt1_config: dict, opt2_config: dict, sch_config: dict): def test_lightning_cli_config_with_subcommand(): config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}} - with mock.patch("sys.argv", ["any.py", f"--config={config}"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config}"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") @@ -1308,9 +1364,10 @@ def test_lightning_cli_config_before_subcommand(): "test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}, } - with mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") @@ -1320,9 +1377,10 @@ def test_lightning_cli_config_before_subcommand(): assert save_config_callback.config.trainer.limit_test_batches == 1 assert save_config_callback.parser.subcommand == "test" - with mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), mock.patch( - "lightning.pytorch.Trainer.validate", autospec=True - ) as validate_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), + mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock, + ): cli = LightningCLI(BoringModel) validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") @@ -1337,17 +1395,19 @@ def test_lightning_cli_config_before_subcommand_two_configs(): config1 = {"validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"}} config2 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}} - with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") assert cli.trainer.limit_test_batches == 1 - with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]), mock.patch( - "lightning.pytorch.Trainer.validate", autospec=True - ) as validate_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]), + mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock, + ): cli = LightningCLI(BoringModel) validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") @@ -1356,9 +1416,10 @@ def test_lightning_cli_config_before_subcommand_two_configs(): def test_lightning_cli_config_after_subcommand(): config = {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"} - with mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") @@ -1368,9 +1429,10 @@ def test_lightning_cli_config_after_subcommand(): def test_lightning_cli_config_before_and_after_subcommand(): config1 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}} config2 = {"trainer": {"fast_dev_run": 1}, "verbose": False, "ckpt_path": "foobar"} - with mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar") @@ -1392,17 +1454,19 @@ def test_lightning_cli_parse_kwargs_with_subcommands(cleandir): "validate": {"default_config_files": [str(validate_config_path)]}, } - with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch( - "lightning.pytorch.Trainer.fit", autospec=True - ) as fit_mock: + with ( + mock.patch("sys.argv", ["any.py", "fit"]), + mock.patch("lightning.pytorch.Trainer.fit", autospec=True) as fit_mock, + ): cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs) fit_mock.assert_called() assert cli.trainer.limit_train_batches == 2 assert cli.trainer.limit_val_batches == 1.0 - with mock.patch("sys.argv", ["any.py", "validate"]), mock.patch( - "lightning.pytorch.Trainer.validate", autospec=True - ) as validate_mock: + with ( + mock.patch("sys.argv", ["any.py", "validate"]), + mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock, + ): cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs) validate_mock.assert_called() assert cli.trainer.limit_train_batches == 1.0 @@ -1420,9 +1484,10 @@ def __init__(self, foo: int, *args, **kwargs): config_path.write_text(str(config)) parser_kwargs = {"default_config_files": [str(config_path)]} - with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch( - "lightning.pytorch.Trainer.fit", autospec=True - ) as fit_mock: + with ( + mock.patch("sys.argv", ["any.py", "fit"]), + mock.patch("lightning.pytorch.Trainer.fit", autospec=True) as fit_mock, + ): cli = LightningCLI(Model, parser_kwargs=parser_kwargs) fit_mock.assert_called() assert cli.model.foo == 123 @@ -1580,8 +1645,13 @@ def _test_logger_init_args(logger_name, init, unresolved=None): def test_comet_logger_init_args(): _test_logger_init_args( "CometLogger", - init={"save_dir": "comet"}, # Resolve from CometLogger.__init__ - unresolved={"workspace": "comet"}, # Resolve from Comet{,Existing,Offline}Experiment.__init__ + init={ + "experiment_key": "some_key", # Resolve from CometLogger.__init__ + "workspace": "comet", + }, + unresolved={ + "save_dir": "comet", # Resolve from CometLogger.__init__ as kwarg + }, ) diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 65c5777e28fed..9e947e0723dcd 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -15,7 +15,7 @@ import os import sys from contextlib import nullcontext -from typing import Any, Dict +from typing import Any from unittest import mock from unittest.mock import Mock @@ -178,7 +178,7 @@ class Accel(Accelerator): def setup_device(self, device: torch.device) -> None: pass - def get_device_stats(self, device: torch.device) -> Dict[str, Any]: + def get_device_stats(self, device: torch.device) -> dict[str, Any]: pass def teardown(self) -> None: diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index 39911f9eddc7f..d29e2285e983c 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -18,7 +18,7 @@ import pytest import torch from lightning.pytorch import Trainer -from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.callbacks import Callback, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.migration.utils import _set_version @@ -234,3 +234,53 @@ def test_strict_loading(strict_loading, expected, tmp_path): trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2) trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt")) model.load_state_dict.assert_called_once_with(ANY, strict=expected) + + +@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"]) +def test_restore_callbacks_in_non_fit_phases(tmp_path, trainer_fn): + """Test that callbacks are properly restored in non-fit phases.""" + + class TestCallback(Callback): + def __init__(self): + self.restored = False + + def on_load_checkpoint(self, trainer, pl_module, checkpoint): + if "callbacks" in checkpoint: + callback_state = checkpoint["callbacks"][self.__class__.__name__] + self.restored = callback_state["restored"] + + def state_dict(self): + return {"restored": self.restored} + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + checkpoint["callbacks"] = checkpoint.get("callbacks", {}) + checkpoint["callbacks"][self.__class__.__name__] = self.state_dict() + + # First create and train a model with the callback + callback = TestCallback() + model = BoringModel() + trainer = Trainer(default_root_dir=tmp_path, callbacks=[callback], max_steps=1) + trainer.fit(model) + + # Set the callback state to True before saving + callback.restored = True + ckpt_path = tmp_path / "checkpoint.ckpt" + trainer.save_checkpoint(ckpt_path) + + # Now create new instances and test restoration + new_callback = TestCallback() + new_model = BoringModel() + assert not new_callback.restored # Should start False + + new_trainer = Trainer(default_root_dir=tmp_path, callbacks=[new_callback]) + + # Connect the model and restore callbacks before evaluation + new_trainer.strategy.connect(new_model) + new_trainer._checkpoint_connector.resume_start(ckpt_path) + new_trainer._checkpoint_connector.restore_callbacks() + + # Run the evaluation phase (validate/test/predict) + fn = getattr(new_trainer, trainer_fn) + fn(new_model, ckpt_path=ckpt_path) + + assert new_callback.restored # Should be True after loading the checkpoint diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index a820a3d6ee786..ca5690ed20f41 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sized from re import escape -from typing import Sized from unittest import mock from unittest.mock import Mock diff --git a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py index 41d9301e4847b..af7cecdb21a08 100644 --- a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from unittest.mock import Mock import lightning.pytorch as pl @@ -38,7 +38,7 @@ def __init__(self): def experiment(self) -> Any: return self.exp - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): + def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None): self.logs.update(metrics) def version(self) -> Union[int, str]: @@ -144,7 +144,7 @@ def __init__(self): self.buffer = {} self.logs = {} - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: self.buffer.update(metrics) def finalize(self, status: str) -> None: diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py index b90d767a23caf..40c82bec2fd10 100644 --- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py @@ -544,7 +544,7 @@ def test_step(self, batch, batch_idx): "valid_loss_1", } assert mock_log_metrics.mock_calls == [ - call({"hp_metric": -1}, 0), + call({"hp_metric": -1}, None), call(metrics={"train_loss": ANY, "epoch": 0}, step=0), call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=0), call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=1), diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py index 451557d084dc7..ac660b6651be5 100644 --- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py @@ -592,9 +592,10 @@ def configure_optimizers(self): limit_train_batches=limit_train_batches, limit_val_batches=0, ) - with mock.patch.object(CustomEpochScheduler, "step") as mock_method_epoch, mock.patch.object( - torch.optim.lr_scheduler.StepLR, "step" - ) as mock_method_step: + with ( + mock.patch.object(CustomEpochScheduler, "step") as mock_method_epoch, + mock.patch.object(torch.optim.lr_scheduler.StepLR, "step") as mock_method_step, + ): trainer.fit(model) assert mock_method_epoch.mock_calls == [call(epoch=e) for e in range(max_epochs)] diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 8946fb4ed9481..d66f3aafee5df 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -1887,8 +1887,9 @@ def training_step(self, batch, batch_idx): model = NanModel() trainer = Trainer(default_root_dir=tmp_path, detect_anomaly=True) - with pytest.raises(RuntimeError, match=r"returned nan values in its 0th output."), pytest.warns( - UserWarning, match=r".*Error detected in.* Traceback of forward call that caused the error.*" + with ( + pytest.raises(RuntimeError, match=r"returned nan values in its 0th output."), + pytest.warns(UserWarning, match=r".*Error detected in.* Traceback of forward call that caused the error.*"), ): trainer.fit(model) @@ -2067,8 +2068,9 @@ def on_fit_start(self): raise exception trainer = Trainer(default_root_dir=tmp_path) - with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, suppress( - Exception, SystemExit + with ( + mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, + suppress(Exception, SystemExit), ): trainer.fit(ExceptionModel()) on_exception_mock.assert_called_once_with(exception) diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index 74f5c1330a089..43a146c6eb089 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -13,7 +13,8 @@ # limitations under the License. import math import pickle -from typing import Any, NamedTuple, Sequence, get_args +from collections.abc import Sequence +from typing import Any, NamedTuple, get_args from unittest.mock import Mock import pytest diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index e9c80d95c58a4..d79e9e24383a0 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -12,6 +12,7 @@ from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.data import ( _get_dataloader_init_args_and_kwargs, + _is_dataloader_shuffled, _update_dataloader, extract_batch_size, has_len_all_ranks, @@ -20,7 +21,7 @@ from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning_utilities.test.warning import no_warning_call from torch import Tensor -from torch.utils.data import BatchSampler, DataLoader, RandomSampler +from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler def test_extract_batch_size(): @@ -304,6 +305,31 @@ def __init__(self, extra_arg): _ = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING) +def test_batch_sampler_shuffle_setting(): + """Test whether the `shuffle` state is correctly set in the `BatchSampler`.""" + + random_sampler = RandomSampler(range(10)) + seq_sampler = SequentialSampler(range(10)) + shuffled_dataloader = DataLoader( + range(10), batch_sampler=BatchSampler(random_sampler, batch_size=2, drop_last=False) + ) + sequential_dataloader = DataLoader( + range(10), batch_sampler=BatchSampler(seq_sampler, batch_size=2, drop_last=False) + ) + + # if batch_size is 1, the pytorch init a default SequentialSampler and set BatchSampler to None + single_dataloader = DataLoader(range(10), batch_sampler=BatchSampler(seq_sampler, batch_size=1, drop_last=False)) + assert _is_dataloader_shuffled(shuffled_dataloader) + assert not _is_dataloader_shuffled(sequential_dataloader) + assert not _is_dataloader_shuffled(single_dataloader) + + # if batch_size is 1, and no batch_sampler is set, the pytorch will set BatchSampler to None + single_dataloader = DataLoader(range(10), batch_size=1) + shuffled_single_dataloader = DataLoader(range(10), batch_size=1, shuffle=True) + assert not _is_dataloader_shuffled(single_dataloader) + assert _is_dataloader_shuffled(shuffled_single_dataloader) + + @pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING]) def test_dataloader_kwargs_replacement_with_iterable_dataset(mode): """Test that DataLoader kwargs are not replaced when using Iterable Dataset.""" diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py index 43a3fad916086..56ee326f076dc 100644 --- a/tests/tests_pytorch/utilities/test_imports.py +++ b/tests/tests_pytorch/utilities/test_imports.py @@ -117,6 +117,13 @@ def test_import_pytorch_lightning_with_torch_dist_unavailable(): code = dedent( """ import torch + try: + # PyTorch 2.5 relies on torch,distributed._composable.fsdp not + # existing with USE_DISTRIBUTED=0 + import torch._dynamo.variables.functions + torch._dynamo.variables.functions._fsdp_param_group = None + except ImportError: + pass # pretend torch.distributed not available for name in list(torch.distributed.__dict__.keys()): @@ -125,6 +132,11 @@ def test_import_pytorch_lightning_with_torch_dist_unavailable(): torch.distributed.is_available = lambda: False + # needed for Dynamo in PT 2.5+ compare the torch.distributed source + class _ProcessGroupStub: + pass + torch.distributed.ProcessGroup = _ProcessGroupStub + import lightning.pytorch """ ) diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index 1bdac616e7b34..2ef1ecd4fe3e5 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -33,8 +33,9 @@ def test_upgrade_checkpoint_file_missing(tmp_path, caplog): # path to non-empty directory, but no checkpoints with matching extension file.touch() - with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path), "--extension", ".other"]), caplog.at_level( - logging.ERROR + with ( + mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path), "--extension", ".other"]), + caplog.at_level(logging.ERROR), ): with pytest.raises(SystemExit): upgrade_main()